diff --git a/.gitignore b/.gitignore
index 39ae5706..9806457b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@
.DS_Store
*.swp
*.swo
+beego.iml
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 00000000..c59cef61
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,30 @@
+language: go
+
+go:
+ - 1.5.1
+
+services:
+ - redis-server
+ - mysql
+ - postgresql
+ - memcached
+env:
+ - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
+ - ORM_DRIVER=mysql ORM_SOURCE="root:@/orm_test?charset=utf8"
+ - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
+install:
+ - go get github.com/lib/pq
+ - go get github.com/go-sql-driver/mysql
+ - go get github.com/mattn/go-sqlite3
+ - go get github.com/bradfitz/gomemcache/memcache
+ - go get github.com/garyburd/redigo/redis
+ - go get github.com/beego/x2j
+ - go get github.com/beego/goyaml2
+ - go get github.com/belogik/goes
+ - go get github.com/couchbase/go-couchbase
+ - go get github.com/siddontang/ledisdb/config
+ - go get github.com/siddontang/ledisdb/ledis
+before_script:
+ - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
+ - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi"
+ - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi"
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..9d511616
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,52 @@
+# Contributing to beego
+
+beego is an open source project.
+
+It is the work of hundreds of contributors. We appreciate your help!
+
+Here are instructions to get you started. They are probably not perfect,
+please let us know if anything feels wrong or incomplete.
+
+## Contribution guidelines
+
+### Pull requests
+
+First of all. beego follow the gitflow. So please send you pull request
+to **develop** branch. We will close the pull request to master branch.
+
+We are always happy to receive pull requests, and do our best to
+review them as fast as possible. Not sure if that typo is worth a pull
+request? Do it! We will appreciate it.
+
+If your pull request is not accepted on the first try, don't be
+discouraged! Sometimes we can make a mistake, please do more explaining
+for us. We will appreciate it.
+
+We're trying very hard to keep beego simple and fast. We don't want it
+to do everything for everybody. This means that we might decide against
+incorporating a new feature. But we will give you some advice on how to
+do it in other way.
+
+### Create issues
+
+Any significant improvement should be documented as [a GitHub
+issue](https://github.com/astaxie/beego/issues) before anybody
+starts working on it.
+
+Also when filing an issue, make sure to answer these five questions:
+
+- What version of beego are you using (bee version)?
+- What operating system and processor architecture are you using?
+- What did you do?
+- What did you expect to see?
+- What did you see instead?
+
+### but check existing issues and docs first!
+
+Please take a moment to check that an issue doesn't already exist
+documenting your bug report or improvement proposal. If it does, it
+never hurts to add a quick "+1" or "I have this problem too". This will
+help prioritize the most common problems and requests.
+
+Also if you don't know how to use it. please make sure you have read though
+the docs in http://beego.me/docs
\ No newline at end of file
diff --git a/README.md b/README.md
index 7b650887..fec6113f 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,38 @@
## Beego
-[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest)
+[![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego)
[![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego)
-beego is an open-source, high-performance, modular, full-stack web framework.
+beego is used for rapid development of RESTful APIs, web apps and backend services in Go.
+It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding.
More info [beego.me](http://beego.me)
-## Installation
+##Quick Start
+######Download and install
go get github.com/astaxie/beego
+######Create file `hello.go`
+```go
+package main
+
+import "github.com/astaxie/beego"
+
+func main(){
+ beego.Run()
+}
+```
+######Build and run
+```bash
+ go build hello.go
+ ./hello
+```
+######Congratulations!
+You just built your first beego app.
+Open your browser and visit `http://localhost:8000`.
+Please see [Documentation](http://beego.me/docs) for more.
+
## Features
* RESTful support
@@ -26,6 +48,7 @@ More info [beego.me](http://beego.me)
* [English](http://beego.me/docs/intro/)
* [中文文档](http://beego.me/docs/intro/)
+* [Русский](http://beego.me/docs/intro/)
## Community
@@ -33,5 +56,5 @@ More info [beego.me](http://beego.me)
## LICENSE
-beego is licensed under the Apache Licence, Version 2.0
+beego source code is licensed under the Apache Licence, Version 2.0
(http://www.apache.org/licenses/LICENSE-2.0.html).
diff --git a/admin.go b/admin.go
index 64d7fe34..3effc582 100644
--- a/admin.go
+++ b/admin.go
@@ -65,24 +65,15 @@ func init() {
// AdminIndex is the default http.Handler for admin module.
// it matches url pattern "/".
func adminIndex(rw http.ResponseWriter, r *http.Request) {
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(indexTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- data := make(map[interface{}]interface{})
- tmpl.Execute(rw, data)
+ execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl)
}
// QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter.
// it's registered with url pattern "/qbs" in admin module.
func qpsIndex(rw http.ResponseWriter, r *http.Request) {
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(qpsTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
data := make(map[interface{}]interface{})
data["Content"] = toolbox.StatisticsMap.GetMap()
-
- tmpl.Execute(rw, data)
-
+ execTpl(rw, data, qpsTpl, defaultScriptsTpl)
}
// ListConf is the http.Handler of displaying all beego configuration values as key/value pair.
@@ -90,178 +81,145 @@ func qpsIndex(rw http.ResponseWriter, r *http.Request) {
func listConf(rw http.ResponseWriter, r *http.Request) {
r.ParseForm()
command := r.Form.Get("command")
- if command != "" {
- data := make(map[interface{}]interface{})
- switch command {
- case "conf":
- m := make(map[string]interface{})
+ if command == "" {
+ rw.Write([]byte("command not support"))
+ return
+ }
- m["AppName"] = AppName
- m["AppPath"] = AppPath
- m["AppConfigPath"] = AppConfigPath
- m["StaticDir"] = StaticDir
- m["StaticExtensionsToGzip"] = StaticExtensionsToGzip
- m["HttpAddr"] = HttpAddr
- m["HttpPort"] = HttpPort
- m["HttpTLS"] = EnableHttpTLS
- m["HttpCertFile"] = HttpCertFile
- m["HttpKeyFile"] = HttpKeyFile
- m["RecoverPanic"] = RecoverPanic
- m["AutoRender"] = AutoRender
- m["ViewsPath"] = ViewsPath
- m["RunMode"] = RunMode
- m["SessionOn"] = SessionOn
- m["SessionProvider"] = SessionProvider
- m["SessionName"] = SessionName
- m["SessionGCMaxLifetime"] = SessionGCMaxLifetime
- m["SessionSavePath"] = SessionSavePath
- m["SessionCookieLifeTime"] = SessionCookieLifeTime
- m["UseFcgi"] = UseFcgi
- m["MaxMemory"] = MaxMemory
- m["EnableGzip"] = EnableGzip
- m["DirectoryIndex"] = DirectoryIndex
- m["HttpServerTimeOut"] = HttpServerTimeOut
- m["ErrorsShow"] = ErrorsShow
- m["XSRFKEY"] = XSRFKEY
- m["EnableXSRF"] = EnableXSRF
- m["XSRFExpire"] = XSRFExpire
- m["CopyRequestBody"] = CopyRequestBody
- m["TemplateLeft"] = TemplateLeft
- m["TemplateRight"] = TemplateRight
- m["BeegoServerName"] = BeegoServerName
- m["EnableAdmin"] = EnableAdmin
- m["AdminHttpAddr"] = AdminHttpAddr
- m["AdminHttpPort"] = AdminHttpPort
+ data := make(map[interface{}]interface{})
+ switch command {
+ case "conf":
+ m := make(map[string]interface{})
+ m["AppConfigPath"] = AppConfigPath
+ m["AppConfigProvider"] = AppConfigProvider
+ m["BConfig.AppName"] = BConfig.AppName
+ m["BConfig.RunMode"] = BConfig.RunMode
+ m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
+ m["BConfig.ServerName"] = BConfig.ServerName
+ m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
+ m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
+ m["BConfig.EnableGzip"] = BConfig.EnableGzip
+ m["BConfig.MaxMemory"] = BConfig.MaxMemory
+ m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
+ m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
+ m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
+ m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
+ m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
+ m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
+ m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
+ m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
+ m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
+ m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
+ m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
+ m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
+ m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
+ m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
+ m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
+ m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
+ m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
+ m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
+ m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
+ m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
+ m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
+ m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
+ m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
+ m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
+ m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
+ m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
+ m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
+ m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
+ m["BConfig.WebConfig.XSRFKEY"] = BConfig.WebConfig.XSRFKey
+ m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
+ m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
+ m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
+ m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
+ m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
+ m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
+ m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
+ m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
+ m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
+ m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
+ m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
+ m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
+ tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
+ tmpl = template.Must(tmpl.Parse(configTpl))
+ tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(configTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
+ data["Content"] = m
- data["Content"] = m
+ tmpl.Execute(rw, data)
- tmpl.Execute(rw, data)
-
- case "router":
- content := make(map[string]interface{})
-
- var fields = []string{
- fmt.Sprintf("Router Pattern"),
- fmt.Sprintf("Methods"),
- fmt.Sprintf("Controller"),
+ case "router":
+ var (
+ content = map[string]interface{}{
+ "Fields": []string{
+ "Router Pattern",
+ "Methods",
+ "Controller",
+ },
}
- content["Fields"] = fields
+ methods = []string{}
+ methodsData = make(map[string]interface{})
+ )
+ for method, t := range BeeApp.Handlers.routers {
- methods := []string{}
- methodsData := make(map[string]interface{})
- for method, t := range BeeApp.Handlers.routers {
+ resultList := new([][]string)
- resultList := new([][]string)
+ printTree(resultList, t)
- printTree(resultList, t)
-
- methods = append(methods, method)
- methodsData[method] = resultList
- }
-
- content["Data"] = methodsData
- content["Methods"] = methods
- data["Content"] = content
- data["Title"] = "Routers"
-
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(routerAndFilterTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- tmpl.Execute(rw, data)
- case "filter":
- content := make(map[string]interface{})
-
- var fields = []string{
- fmt.Sprintf("Router Pattern"),
- fmt.Sprintf("Filter Function"),
- }
- content["Fields"] = fields
-
- filterTypes := []string{}
- filterTypeData := make(map[string]interface{})
-
- if BeeApp.Handlers.enableFilter {
- var filterType string
-
- if bf, ok := BeeApp.Handlers.filters[BeforeRouter]; ok {
- filterType = "Before Router"
- filterTypes = append(filterTypes, filterType)
- resultList := new([][]string)
- for _, f := range bf {
-
- var result = []string{
- fmt.Sprintf("%s", f.pattern),
- fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
- }
- *resultList = append(*resultList, result)
- }
- filterTypeData[filterType] = resultList
- }
-
- if bf, ok := BeeApp.Handlers.filters[BeforeExec]; ok {
- filterType = "Before Exec"
- filterTypes = append(filterTypes, filterType)
- resultList := new([][]string)
- for _, f := range bf {
-
- var result = []string{
- fmt.Sprintf("%s", f.pattern),
- fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
- }
- *resultList = append(*resultList, result)
- }
- filterTypeData[filterType] = resultList
- }
-
- if bf, ok := BeeApp.Handlers.filters[AfterExec]; ok {
- filterType = "After Exec"
- filterTypes = append(filterTypes, filterType)
- resultList := new([][]string)
- for _, f := range bf {
-
- var result = []string{
- fmt.Sprintf("%s", f.pattern),
- fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
- }
- *resultList = append(*resultList, result)
- }
- filterTypeData[filterType] = resultList
- }
-
- if bf, ok := BeeApp.Handlers.filters[FinishRouter]; ok {
- filterType = "Finish Router"
- filterTypes = append(filterTypes, filterType)
- resultList := new([][]string)
- for _, f := range bf {
-
- var result = []string{
- fmt.Sprintf("%s", f.pattern),
- fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
- }
- *resultList = append(*resultList, result)
- }
- filterTypeData[filterType] = resultList
- }
- }
-
- content["Data"] = filterTypeData
- content["Methods"] = filterTypes
-
- data["Content"] = content
- data["Title"] = "Filters"
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(routerAndFilterTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- tmpl.Execute(rw, data)
-
- default:
- rw.Write([]byte("command not support"))
+ methods = append(methods, method)
+ methodsData[method] = resultList
}
- } else {
+
+ content["Data"] = methodsData
+ content["Methods"] = methods
+ data["Content"] = content
+ data["Title"] = "Routers"
+ execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
+ case "filter":
+ var (
+ content = map[string]interface{}{
+ "Fields": []string{
+ "Router Pattern",
+ "Filter Function",
+ },
+ }
+ filterTypes = []string{}
+ filterTypeData = make(map[string]interface{})
+ )
+
+ if BeeApp.Handlers.enableFilter {
+ var filterType string
+ for k, fr := range map[int]string{
+ BeforeStatic: "Before Static",
+ BeforeRouter: "Before Router",
+ BeforeExec: "Before Exec",
+ AfterExec: "After Exec",
+ FinishRouter: "Finish Router"} {
+ if bf, ok := BeeApp.Handlers.filters[k]; ok {
+ filterType = fr
+ filterTypes = append(filterTypes, filterType)
+ resultList := new([][]string)
+ for _, f := range bf {
+ var result = []string{
+ fmt.Sprintf("%s", f.pattern),
+ fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
+ }
+ *resultList = append(*resultList, result)
+ }
+ filterTypeData[filterType] = resultList
+ }
+ }
+ }
+
+ content["Data"] = filterTypeData
+ content["Methods"] = filterTypes
+
+ data["Content"] = content
+ data["Title"] = "Filters"
+ execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
+ default:
+ rw.Write([]byte("command not support"))
}
}
@@ -276,23 +234,23 @@ func printTree(resultList *[][]string, t *Tree) {
if v, ok := l.runObject.(*controllerInfo); ok {
if v.routerType == routerTypeBeego {
var result = []string{
- fmt.Sprintf("%s", v.pattern),
+ v.pattern,
fmt.Sprintf("%s", v.methods),
fmt.Sprintf("%s", v.controllerType),
}
*resultList = append(*resultList, result)
} else if v.routerType == routerTypeRESTFul {
var result = []string{
- fmt.Sprintf("%s", v.pattern),
+ v.pattern,
fmt.Sprintf("%s", v.methods),
- fmt.Sprintf(""),
+ "",
}
*resultList = append(*resultList, result)
} else if v.routerType == routerTypeHandler {
var result = []string{
- fmt.Sprintf("%s", v.pattern),
- fmt.Sprintf(""),
- fmt.Sprintf(""),
+ v.pattern,
+ "",
+ "",
}
*resultList = append(*resultList, result)
}
@@ -305,53 +263,49 @@ func printTree(resultList *[][]string, t *Tree) {
func profIndex(rw http.ResponseWriter, r *http.Request) {
r.ParseForm()
command := r.Form.Get("command")
- format := r.Form.Get("format")
- data := make(map[string]interface{})
+ if command == "" {
+ return
+ }
- var result bytes.Buffer
- if command != "" {
- toolbox.ProcessInput(command, &result)
- data["Content"] = result.String()
+ var (
+ format = r.Form.Get("format")
+ data = make(map[interface{}]interface{})
+ result bytes.Buffer
+ )
+ toolbox.ProcessInput(command, &result)
+ data["Content"] = result.String()
- if format == "json" && command == "gc summary" {
- dataJson, err := json.Marshal(data)
- if err != nil {
- http.Error(rw, err.Error(), http.StatusInternalServerError)
- return
- }
-
- rw.Header().Set("Content-Type", "application/json")
- rw.Write(dataJson)
+ if format == "json" && command == "gc summary" {
+ dataJSON, err := json.Marshal(data)
+ if err != nil {
+ http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
- data["Title"] = command
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(profillingTpl))
- if command == "gc summary" {
- tmpl = template.Must(tmpl.Parse(gcAjaxTpl))
- } else {
-
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- }
- tmpl.Execute(rw, data)
+ rw.Header().Set("Content-Type", "application/json")
+ rw.Write(dataJSON)
+ return
}
+
+ data["Title"] = command
+ defaultTpl := defaultScriptsTpl
+ if command == "gc summary" {
+ defaultTpl = gcAjaxTpl
+ }
+ execTpl(rw, data, profillingTpl, defaultTpl)
}
// Healthcheck is a http.Handler calling health checking and showing the result.
// it's in "/healthcheck" pattern in admin module.
func healthcheck(rw http.ResponseWriter, req *http.Request) {
- data := make(map[interface{}]interface{})
-
- var result = []string{}
- fields := []string{
- fmt.Sprintf("Name"),
- fmt.Sprintf("Message"),
- fmt.Sprintf("Status"),
- }
- resultList := new([][]string)
-
- content := make(map[string]interface{})
+ var (
+ data = make(map[interface{}]interface{})
+ result = []string{}
+ resultList = new([][]string)
+ content = map[string]interface{}{
+ "Fields": []string{"Name", "Message", "Status"},
+ }
+ )
for name, h := range toolbox.AdminCheckList {
if err := h.Check(); err != nil {
@@ -371,16 +325,10 @@ func healthcheck(rw http.ResponseWriter, req *http.Request) {
}
*resultList = append(*resultList, result)
}
-
- content["Fields"] = fields
content["Data"] = resultList
data["Content"] = content
data["Title"] = "Health Check"
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(healthCheckTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- tmpl.Execute(rw, data)
-
+ execTpl(rw, data, healthCheckTpl, defaultScriptsTpl)
}
// TaskStatus is a http.Handler with running task status (task name, status and the last execution).
@@ -392,10 +340,8 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
req.ParseForm()
taskname := req.Form.Get("taskname")
if taskname != "" {
-
if t, ok := toolbox.AdminTaskList[taskname]; ok {
- err := t.Run()
- if err != nil {
+ if err := t.Run(); err != nil {
data["Message"] = []string{"error", fmt.Sprintf("%s", err)}
}
data["Message"] = []string{"success", fmt.Sprintf("%s run success,Now the Status is %s", taskname, t.GetStatus())}
@@ -409,18 +355,18 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
resultList := new([][]string)
var result = []string{}
var fields = []string{
- fmt.Sprintf("Task Name"),
- fmt.Sprintf("Task Spec"),
- fmt.Sprintf("Task Status"),
- fmt.Sprintf("Last Time"),
- fmt.Sprintf(""),
+ "Task Name",
+ "Task Spec",
+ "Task Status",
+ "Last Time",
+ "",
}
for tname, tk := range toolbox.AdminTaskList {
result = []string{
- fmt.Sprintf("%s", tname),
+ tname,
fmt.Sprintf("%s", tk.GetSpec()),
fmt.Sprintf("%s", tk.GetStatus()),
- fmt.Sprintf("%s", tk.GetPrev().String()),
+ tk.GetPrev().String(),
}
*resultList = append(*resultList, result)
}
@@ -429,9 +375,14 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
content["Data"] = resultList
data["Content"] = content
data["Title"] = "Tasks"
+ execTpl(rw, data, tasksTpl, defaultScriptsTpl)
+}
+
+func execTpl(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) {
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(tasksTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
+ for _, tpl := range tpls {
+ tmpl = template.Must(tmpl.Parse(tpl))
+ }
tmpl.Execute(rw, data)
}
@@ -451,10 +402,10 @@ func (admin *adminApp) Run() {
if len(toolbox.AdminTaskList) > 0 {
toolbox.StartTask()
}
- addr := AdminHttpAddr
+ addr := BConfig.Listen.AdminAddr
- if AdminHttpPort != 0 {
- addr = fmt.Sprintf("%s:%d", AdminHttpAddr, AdminHttpPort)
+ if BConfig.Listen.AdminPort != 0 {
+ addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort)
}
for p, f := range admin.routers {
http.Handle(p, f)
@@ -462,7 +413,7 @@ func (admin *adminApp) Run() {
BeeLogger.Info("Admin server Running on %s", addr)
var err error
- if Graceful {
+ if BConfig.Listen.Graceful {
err = grace.ListenAndServe(addr, nil)
} else {
err = http.ListenAndServe(addr, nil)
diff --git a/app.go b/app.go
index 8fc320ad..af54ea4b 100644
--- a/app.go
+++ b/app.go
@@ -20,15 +20,26 @@ import (
"net/http"
"net/http/fcgi"
"os"
+ "path"
"time"
"github.com/astaxie/beego/grace"
"github.com/astaxie/beego/utils"
)
+var (
+ // BeeApp is an application instance
+ BeeApp *App
+)
+
+func init() {
+ // create beego application
+ BeeApp = NewApp()
+}
+
// App defines beego application with a new PatternServeMux.
type App struct {
- Handlers *ControllerRegistor
+ Handlers *ControllerRegister
Server *http.Server
}
@@ -41,132 +52,311 @@ func NewApp() *App {
// Run beego application.
func (app *App) Run() {
- addr := HttpAddr
+ addr := BConfig.Listen.HTTPAddr
- if HttpPort != 0 {
- addr = fmt.Sprintf("%s:%d", HttpAddr, HttpPort)
+ if BConfig.Listen.HTTPPort != 0 {
+ addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort)
}
var (
- err error
- l net.Listener
+ err error
+ l net.Listener
+ endRunning = make(chan bool, 1)
)
- endRunning := make(chan bool, 1)
- if UseFcgi {
- if UseStdIo {
- err = fcgi.Serve(nil, app.Handlers) // standard I/O
- if err == nil {
+ // run cgi server
+ if BConfig.Listen.EnableFcgi {
+ if BConfig.Listen.EnableStdIo {
+ if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
BeeLogger.Info("Use FCGI via standard I/O")
} else {
- BeeLogger.Info("Cannot use FCGI via standard I/O", err)
+ BeeLogger.Critical("Cannot use FCGI via standard I/O", err)
}
+ return
+ }
+ if BConfig.Listen.HTTPPort == 0 {
+ // remove the Socket file before start
+ if utils.FileExists(addr) {
+ os.Remove(addr)
+ }
+ l, err = net.Listen("unix", addr)
} else {
- if HttpPort == 0 {
- // remove the Socket file before start
- if utils.FileExists(addr) {
- os.Remove(addr)
+ l, err = net.Listen("tcp", addr)
+ }
+ if err != nil {
+ BeeLogger.Critical("Listen: ", err)
+ }
+ if err = fcgi.Serve(l, app.Handlers); err != nil {
+ BeeLogger.Critical("fcgi.Serve: ", err)
+ }
+ return
+ }
+
+ app.Server.Handler = app.Handlers
+ app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
+ app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
+
+ // run graceful mode
+ if BConfig.Listen.Graceful {
+ httpsAddr := BConfig.Listen.HTTPSAddr
+ app.Server.Addr = httpsAddr
+ if BConfig.Listen.EnableHTTPS {
+ go func() {
+ time.Sleep(20 * time.Microsecond)
+ if BConfig.Listen.HTTPSPort != 0 {
+ httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
+ app.Server.Addr = httpsAddr
+ }
+ server := grace.NewServer(httpsAddr, app.Handlers)
+ server.Server.ReadTimeout = app.Server.ReadTimeout
+ server.Server.WriteTimeout = app.Server.WriteTimeout
+ if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
+ BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
+ }
+ }()
+ }
+ if BConfig.Listen.EnableHTTP {
+ go func() {
+ server := grace.NewServer(addr, app.Handlers)
+ server.Server.ReadTimeout = app.Server.ReadTimeout
+ server.Server.WriteTimeout = app.Server.WriteTimeout
+ if BConfig.Listen.ListenTCP4 {
+ server.Network = "tcp4"
+ }
+ if err := server.ListenAndServe(); err != nil {
+ BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
+ }
+ }()
+ }
+ <-endRunning
+ return
+ }
+
+ // run normal mode
+ app.Server.Addr = addr
+ if BConfig.Listen.EnableHTTPS {
+ go func() {
+ time.Sleep(20 * time.Microsecond)
+ if BConfig.Listen.HTTPSPort != 0 {
+ app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
+ }
+ BeeLogger.Info("https server Running on %s", app.Server.Addr)
+ if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
+ BeeLogger.Critical("ListenAndServeTLS: ", err)
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
+ }
+ }()
+ }
+ if BConfig.Listen.EnableHTTP {
+ go func() {
+ app.Server.Addr = addr
+ BeeLogger.Info("http server Running on %s", app.Server.Addr)
+ if BConfig.Listen.ListenTCP4 {
+ ln, err := net.Listen("tcp4", app.Server.Addr)
+ if err != nil {
+ BeeLogger.Critical("ListenAndServe: ", err)
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
+ return
+ }
+ if err = app.Server.Serve(ln); err != nil {
+ BeeLogger.Critical("ListenAndServe: ", err)
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
+ return
}
- l, err = net.Listen("unix", addr)
} else {
- l, err = net.Listen("tcp", addr)
+ if err := app.Server.ListenAndServe(); err != nil {
+ BeeLogger.Critical("ListenAndServe: ", err)
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
+ }
}
- if err != nil {
- BeeLogger.Critical("Listen: ", err)
- }
- err = fcgi.Serve(l, app.Handlers)
- }
- } else {
- if Graceful {
- app.Server.Addr = addr
- app.Server.Handler = app.Handlers
- app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second
- app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second
- if EnableHttpTLS {
- go func() {
- time.Sleep(20 * time.Microsecond)
- if HttpsPort != 0 {
- addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort)
- app.Server.Addr = addr
- }
- server := grace.NewServer(addr, app.Handlers)
- server.Server = app.Server
- err := server.ListenAndServeTLS(HttpCertFile, HttpKeyFile)
- if err != nil {
- BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
- time.Sleep(100 * time.Microsecond)
- endRunning <- true
- }
- }()
- }
- if EnableHttpListen {
- go func() {
- server := grace.NewServer(addr, app.Handlers)
- server.Server = app.Server
- if ListenTCP4 && HttpAddr == "" {
- server.Network = "tcp4"
- }
- err := server.ListenAndServe()
- if err != nil {
- BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
- time.Sleep(100 * time.Microsecond)
- endRunning <- true
- }
- }()
- }
- } else {
- app.Server.Addr = addr
- app.Server.Handler = app.Handlers
- app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second
- app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second
-
- if EnableHttpTLS {
- go func() {
- time.Sleep(20 * time.Microsecond)
- if HttpsPort != 0 {
- app.Server.Addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort)
- }
- BeeLogger.Info("https server Running on %s", app.Server.Addr)
- err := app.Server.ListenAndServeTLS(HttpCertFile, HttpKeyFile)
- if err != nil {
- BeeLogger.Critical("ListenAndServeTLS: ", err)
- time.Sleep(100 * time.Microsecond)
- endRunning <- true
- }
- }()
- }
-
- if EnableHttpListen {
- go func() {
- app.Server.Addr = addr
- BeeLogger.Info("http server Running on %s", app.Server.Addr)
- if ListenTCP4 && HttpAddr == "" {
- ln, err := net.Listen("tcp4", app.Server.Addr)
- if err != nil {
- BeeLogger.Critical("ListenAndServe: ", err)
- time.Sleep(100 * time.Microsecond)
- endRunning <- true
- return
- }
- err = app.Server.Serve(ln)
- if err != nil {
- BeeLogger.Critical("ListenAndServe: ", err)
- time.Sleep(100 * time.Microsecond)
- endRunning <- true
- return
- }
- } else {
- err := app.Server.ListenAndServe()
- if err != nil {
- BeeLogger.Critical("ListenAndServe: ", err)
- time.Sleep(100 * time.Microsecond)
- endRunning <- true
- }
- }
- }()
- }
- }
-
+ }()
}
<-endRunning
}
+
+// Router adds a patterned controller handler to BeeApp.
+// it's an alias method of App.Router.
+// usage:
+// simple router
+// beego.Router("/admin", &admin.UserController{})
+// beego.Router("/admin/index", &admin.ArticleController{})
+//
+// regex router
+//
+// beego.Router("/api/:id([0-9]+)", &controllers.RController{})
+//
+// custom rules
+// beego.Router("/api/list",&RestController{},"*:ListFood")
+// beego.Router("/api/create",&RestController{},"post:CreateFood")
+// beego.Router("/api/update",&RestController{},"put:UpdateFood")
+// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
+func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
+ BeeApp.Handlers.Add(rootpath, c, mappingMethods...)
+ return BeeApp
+}
+
+// Include will generate router file in the router/xxx.go from the controller's comments
+// usage:
+// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
+// type BankAccount struct{
+// beego.Controller
+// }
+//
+// register the function
+// func (b *BankAccount)Mapping(){
+// b.Mapping("ShowAccount" , b.ShowAccount)
+// b.Mapping("ModifyAccount", b.ModifyAccount)
+//}
+//
+// //@router /account/:id [get]
+// func (b *BankAccount) ShowAccount(){
+// //logic
+// }
+//
+//
+// //@router /account/:id [post]
+// func (b *BankAccount) ModifyAccount(){
+// //logic
+// }
+//
+// the comments @router url methodlist
+// url support all the function Router's pattern
+// methodlist [get post head put delete options *]
+func Include(cList ...ControllerInterface) *App {
+ BeeApp.Handlers.Include(cList...)
+ return BeeApp
+}
+
+// RESTRouter adds a restful controller handler to BeeApp.
+// its' controller implements beego.ControllerInterface and
+// defines a param "pattern/:objectId" to visit each resource.
+func RESTRouter(rootpath string, c ControllerInterface) *App {
+ Router(rootpath, c)
+ Router(path.Join(rootpath, ":objectId"), c)
+ return BeeApp
+}
+
+// AutoRouter adds defined controller handler to BeeApp.
+// it's same to App.AutoRouter.
+// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
+// visit the url /main/list to exec List function or /main/page to exec Page function.
+func AutoRouter(c ControllerInterface) *App {
+ BeeApp.Handlers.AddAuto(c)
+ return BeeApp
+}
+
+// AutoPrefix adds controller handler to BeeApp with prefix.
+// it's same to App.AutoRouterWithPrefix.
+// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
+// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
+func AutoPrefix(prefix string, c ControllerInterface) *App {
+ BeeApp.Handlers.AddAutoPrefix(prefix, c)
+ return BeeApp
+}
+
+// Get used to register router for Get method
+// usage:
+// beego.Get("/", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Get(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Get(rootpath, f)
+ return BeeApp
+}
+
+// Post used to register router for Post method
+// usage:
+// beego.Post("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Post(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Post(rootpath, f)
+ return BeeApp
+}
+
+// Delete used to register router for Delete method
+// usage:
+// beego.Delete("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Delete(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Delete(rootpath, f)
+ return BeeApp
+}
+
+// Put used to register router for Put method
+// usage:
+// beego.Put("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Put(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Put(rootpath, f)
+ return BeeApp
+}
+
+// Head used to register router for Head method
+// usage:
+// beego.Head("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Head(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Head(rootpath, f)
+ return BeeApp
+}
+
+// Options used to register router for Options method
+// usage:
+// beego.Options("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Options(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Options(rootpath, f)
+ return BeeApp
+}
+
+// Patch used to register router for Patch method
+// usage:
+// beego.Patch("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Patch(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Patch(rootpath, f)
+ return BeeApp
+}
+
+// Any used to register router for all methods
+// usage:
+// beego.Any("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Any(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Any(rootpath, f)
+ return BeeApp
+}
+
+// Handler used to register a Handler router
+// usage:
+// beego.Handler("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
+ BeeApp.Handlers.Handler(rootpath, h, options...)
+ return BeeApp
+}
+
+// InsertFilter adds a FilterFunc with pattern condition and action constant.
+// The pos means action constant including
+// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
+// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
+func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App {
+ BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
+ return BeeApp
+}
diff --git a/beego.go b/beego.go
index cfebfbea..04f02071 100644
--- a/beego.go
+++ b/beego.go
@@ -12,243 +12,35 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// beego is an open-source, high-performance, modularity, full-stack web framework
-//
-// package main
-//
-// import "github.com/astaxie/beego"
-//
-// func main() {
-// beego.Run()
-// }
-//
-// more infomation: http://beego.me
package beego
import (
- "net/http"
+ "fmt"
"os"
- "path"
"path/filepath"
"strconv"
"strings"
-
- "github.com/astaxie/beego/session"
)
-// beego web framework version.
-const VERSION = "1.5.0"
+const (
+ // VERSION represent beego web framework version.
+ VERSION = "1.6.0"
-type hookfunc func() error //hook function to run
-var hooks []hookfunc //hook function slice to store the hookfunc
+ // DEV is for develop
+ DEV = "dev"
+ // PROD is for production
+ PROD = "prod"
+)
-// Router adds a patterned controller handler to BeeApp.
-// it's an alias method of App.Router.
-// usage:
-// simple router
-// beego.Router("/admin", &admin.UserController{})
-// beego.Router("/admin/index", &admin.ArticleController{})
-//
-// regex router
-//
-// beego.Router("/api/:id([0-9]+)", &controllers.RController{})
-//
-// custom rules
-// beego.Router("/api/list",&RestController{},"*:ListFood")
-// beego.Router("/api/create",&RestController{},"post:CreateFood")
-// beego.Router("/api/update",&RestController{},"put:UpdateFood")
-// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
-func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
- BeeApp.Handlers.Add(rootpath, c, mappingMethods...)
- return BeeApp
-}
+//hook function to run
+type hookfunc func() error
-// Router add list from
-// usage:
-// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
-// type BankAccount struct{
-// beego.Controller
-// }
-//
-// register the function
-// func (b *BankAccount)Mapping(){
-// b.Mapping("ShowAccount" , b.ShowAccount)
-// b.Mapping("ModifyAccount", b.ModifyAccount)
-//}
-//
-// //@router /account/:id [get]
-// func (b *BankAccount) ShowAccount(){
-// //logic
-// }
-//
-//
-// //@router /account/:id [post]
-// func (b *BankAccount) ModifyAccount(){
-// //logic
-// }
-//
-// the comments @router url methodlist
-// url support all the function Router's pattern
-// methodlist [get post head put delete options *]
-func Include(cList ...ControllerInterface) *App {
- BeeApp.Handlers.Include(cList...)
- return BeeApp
-}
+var (
+ hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc
+)
-// RESTRouter adds a restful controller handler to BeeApp.
-// its' controller implements beego.ControllerInterface and
-// defines a param "pattern/:objectId" to visit each resource.
-func RESTRouter(rootpath string, c ControllerInterface) *App {
- Router(rootpath, c)
- Router(path.Join(rootpath, ":objectId"), c)
- return BeeApp
-}
-
-// AutoRouter adds defined controller handler to BeeApp.
-// it's same to App.AutoRouter.
-// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
-// visit the url /main/list to exec List function or /main/page to exec Page function.
-func AutoRouter(c ControllerInterface) *App {
- BeeApp.Handlers.AddAuto(c)
- return BeeApp
-}
-
-// AutoPrefix adds controller handler to BeeApp with prefix.
-// it's same to App.AutoRouterWithPrefix.
-// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
-// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
-func AutoPrefix(prefix string, c ControllerInterface) *App {
- BeeApp.Handlers.AddAutoPrefix(prefix, c)
- return BeeApp
-}
-
-// register router for Get method
-// usage:
-// beego.Get("/", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Get(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Get(rootpath, f)
- return BeeApp
-}
-
-// register router for Post method
-// usage:
-// beego.Post("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Post(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Post(rootpath, f)
- return BeeApp
-}
-
-// register router for Delete method
-// usage:
-// beego.Delete("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Delete(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Delete(rootpath, f)
- return BeeApp
-}
-
-// register router for Put method
-// usage:
-// beego.Put("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Put(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Put(rootpath, f)
- return BeeApp
-}
-
-// register router for Head method
-// usage:
-// beego.Head("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Head(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Head(rootpath, f)
- return BeeApp
-}
-
-// register router for Options method
-// usage:
-// beego.Options("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Options(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Options(rootpath, f)
- return BeeApp
-}
-
-// register router for Patch method
-// usage:
-// beego.Patch("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Patch(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Patch(rootpath, f)
- return BeeApp
-}
-
-// register router for all method
-// usage:
-// beego.Any("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Any(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Any(rootpath, f)
- return BeeApp
-}
-
-// register router for own Handler
-// usage:
-// beego.Handler("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
- BeeApp.Handlers.Handler(rootpath, h, options...)
- return BeeApp
-}
-
-// SetViewsPath sets view directory path in beego application.
-func SetViewsPath(path string) *App {
- ViewsPath = path
- return BeeApp
-}
-
-// SetStaticPath sets static directory path and proper url pattern in beego application.
-// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public".
-func SetStaticPath(url string, path string) *App {
- if !strings.HasPrefix(url, "/") {
- url = "/" + url
- }
- url = strings.TrimRight(url, "/")
- StaticDir[url] = path
- return BeeApp
-}
-
-// DelStaticPath removes the static folder setting in this url pattern in beego application.
-func DelStaticPath(url string) *App {
- if !strings.HasPrefix(url, "/") {
- url = "/" + url
- }
- url = strings.TrimRight(url, "/")
- delete(StaticDir, url)
- return BeeApp
-}
-
-// InsertFilter adds a FilterFunc with pattern condition and action constant.
-// The pos means action constant including
-// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
-// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
-func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App {
- BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
- return BeeApp
-}
-
-// The hookfunc will run in beego.Run()
+// AddAPPStartHook is used to register the hookfunc
+// The hookfuncs will run in beego.Run()
// such as sessionInit, middlerware start, buildtemplate, admin start
func AddAPPStartHook(hf hookfunc) {
hooks = append(hooks, hf)
@@ -256,97 +48,60 @@ func AddAPPStartHook(hf hookfunc) {
// Run beego application.
// beego.Run() default run on HttpPort
+// beego.Run("localhost")
// beego.Run(":8089")
// beego.Run("127.0.0.1:8089")
func Run(params ...string) {
+ initBeforeHTTPRun()
+
if len(params) > 0 && params[0] != "" {
strs := strings.Split(params[0], ":")
if len(strs) > 0 && strs[0] != "" {
- HttpAddr = strs[0]
+ BConfig.Listen.HTTPAddr = strs[0]
}
if len(strs) > 1 && strs[1] != "" {
- HttpPort, _ = strconv.Atoi(strs[1])
+ BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1])
}
}
- initBeforeHttpRun()
-
- if EnableAdmin {
- go beeAdminApp.Run()
- }
BeeApp.Run()
}
-func initBeforeHttpRun() {
- // if AppConfigPath not In the conf/app.conf reParse config
- if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") {
- err := ParseConfig()
- if err != nil && AppConfigPath != filepath.Join(workPath, "conf", "app.conf") {
- // configuration is critical to app, panic here if parse failed
- panic(err)
- }
- }
-
- //init mime
- AddAPPStartHook(initMime)
-
- // do hooks function
- for _, hk := range hooks {
- err := hk()
- if err != nil {
- panic(err)
- }
- }
-
- if SessionOn {
- var err error
- sessionConfig := AppConfig.String("sessionConfig")
- if sessionConfig == "" {
- sessionConfig = `{"cookieName":"` + SessionName + `",` +
- `"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
- `"providerConfig":"` + filepath.ToSlash(SessionSavePath) + `",` +
- `"secure":` + strconv.FormatBool(EnableHttpTLS) + `,` +
- `"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
- `"domain":"` + SessionDomain + `",` +
- `"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
- }
- GlobalSessions, err = session.NewManager(SessionProvider,
- sessionConfig)
- if err != nil {
- panic(err)
- }
- go GlobalSessions.GC()
- }
-
- err := BuildTemplate(ViewsPath)
- if err != nil {
- if RunMode == "dev" {
- Warn(err)
- }
- }
-
- registerDefaultErrorHandler()
-
- if EnableDocs {
- Get("/docs", serverDocs)
- Get("/docs/*", serverDocs)
- }
-}
-
-// this function is for test package init
-func TestBeegoInit(apppath string) {
- AppPath = apppath
- os.Setenv("BEEGO_RUNMODE", "test")
- AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
+func initBeforeHTTPRun() {
+ // if AppConfigPath is setted or conf/app.conf exist
err := ParseConfig()
- if err != nil && !os.IsNotExist(err) {
- // for init if doesn't have app.conf will not panic
- Info(err)
+ if err != nil {
+ panic(err)
+ }
+ //init log
+ for adaptor, config := range BConfig.Log.Outputs {
+ err = BeeLogger.SetLogger(adaptor, config)
+ if err != nil {
+ fmt.Printf("%s with the config `%s` got err:%s\n", adaptor, config, err)
+ }
+ }
+
+ SetLogFuncCall(BConfig.Log.FileLineNum)
+
+ //init hooks
+ AddAPPStartHook(registerMime)
+ AddAPPStartHook(registerDefaultErrorHandler)
+ AddAPPStartHook(registerSession)
+ AddAPPStartHook(registerDocs)
+ AddAPPStartHook(registerTemplate)
+ AddAPPStartHook(registerAdmin)
+
+ for _, hk := range hooks {
+ if err := hk(); err != nil {
+ panic(err)
+ }
}
- os.Chdir(AppPath)
- initBeforeHttpRun()
}
-func init() {
- hooks = make([]hookfunc, 0)
+// TestBeegoInit is for test package init
+func TestBeegoInit(ap string) {
+ os.Setenv("BEEGO_RUNMODE", "test")
+ AppConfigPath = filepath.Join(ap, "conf", "app.conf")
+ os.Chdir(ap)
+ initBeforeHTTPRun()
}
diff --git a/cache/README.md b/cache/README.md
index 72d0d1c5..957790e7 100644
--- a/cache/README.md
+++ b/cache/README.md
@@ -26,7 +26,7 @@ Then init a Cache (example with memory adapter)
Use it like this:
- bm.Put("astaxie", 1, 10)
+ bm.Put("astaxie", 1, 10 * time.Second)
bm.Get("astaxie")
bm.IsExist("astaxie")
bm.Delete("astaxie")
@@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
## Memcache adapter
-Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
+Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client.
Configure like this:
@@ -52,7 +52,7 @@ Configure like this:
## Redis adapter
-Redis adapter use the [redigo](http://github.com/garyburd/redigo/redis) client.
+Redis adapter use the [redigo](http://github.com/garyburd/redigo) client.
Configure like this:
diff --git a/cache/cache.go b/cache/cache.go
index 7ca87802..2008402e 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package cache provide a Cache interface and some implemetn engine
// Usage:
//
// import(
@@ -22,7 +23,7 @@
//
// Use it like this:
//
-// bm.Put("astaxie", 1, 10)
+// bm.Put("astaxie", 1, 10 * time.Second)
// bm.Get("astaxie")
// bm.IsExist("astaxie")
// bm.Delete("astaxie")
@@ -32,13 +33,14 @@ package cache
import (
"fmt"
+ "time"
)
// Cache interface contains all behaviors for cache adapter.
// usage:
-// cache.Register("file",cache.NewFileCache()) // this operation is run in init method of file.go.
+// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go.
// c,err := cache.NewCache("file","{....}")
-// c.Put("key",value,3600)
+// c.Put("key",value, 3600 * time.Second)
// v := c.Get("key")
//
// c.Incr("counter") // now is 1
@@ -50,7 +52,7 @@ type Cache interface {
// GetMulti is a batch version of Get.
GetMulti(keys []string) []interface{}
// set cached value with key and expire time.
- Put(key string, val interface{}, timeout int64) error
+ Put(key string, val interface{}, timeout time.Duration) error
// delete cached value by key.
Delete(key string) error
// increase cached int value by key, as a counter.
@@ -65,12 +67,14 @@ type Cache interface {
StartAndGC(config string) error
}
-var adapters = make(map[string]Cache)
+type CacheInstance func() Cache
+
+var adapters = make(map[string]CacheInstance)
// Register makes a cache adapter available by the adapter name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
-func Register(name string, adapter Cache) {
+func Register(name string, adapter CacheInstance) {
if adapter == nil {
panic("cache: Register adapter is nil")
}
@@ -80,15 +84,16 @@ func Register(name string, adapter Cache) {
adapters[name] = adapter
}
-// Create a new cache driver by adapter name and config string.
+// NewCache Create a new cache driver by adapter name and config string.
// config need to be correct JSON as string: {"interval":360}.
// it will start gc automatically.
func NewCache(adapterName, config string) (adapter Cache, err error) {
- adapter, ok := adapters[adapterName]
+ instanceFunc, ok := adapters[adapterName]
if !ok {
err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName)
return
}
+ adapter = instanceFunc()
err = adapter.StartAndGC(config)
if err != nil {
adapter = nil
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 481309fd..9ceb606a 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -25,7 +25,8 @@ func TestCache(t *testing.T) {
if err != nil {
t.Error("init err")
}
- if err = bm.Put("astaxie", 1, 10); err != nil {
+ timeoutDuration := 10 * time.Second
+ if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@@ -42,7 +43,7 @@ func TestCache(t *testing.T) {
t.Error("check err")
}
- if err = bm.Put("astaxie", 1, 10); err != nil {
+ if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
@@ -67,7 +68,7 @@ func TestCache(t *testing.T) {
}
//test GetMulti
- if err = bm.Put("astaxie", "author", 10); err != nil {
+ if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@@ -77,7 +78,7 @@ func TestCache(t *testing.T) {
t.Error("get err")
}
- if err = bm.Put("astaxie1", "author1", 10); err != nil {
+ if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie1") {
@@ -101,7 +102,8 @@ func TestFileCache(t *testing.T) {
if err != nil {
t.Error("init err")
}
- if err = bm.Put("astaxie", 1, 10); err != nil {
+ timeoutDuration := 10 * time.Second
+ if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@@ -133,7 +135,7 @@ func TestFileCache(t *testing.T) {
}
//test string
- if err = bm.Put("astaxie", "author", 10); err != nil {
+ if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@@ -144,7 +146,7 @@ func TestFileCache(t *testing.T) {
}
//test GetMulti
- if err = bm.Put("astaxie1", "author1", 10); err != nil {
+ if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie1") {
diff --git a/cache/conv.go b/cache/conv.go
index 724abfd2..dbdff1c7 100644
--- a/cache/conv.go
+++ b/cache/conv.go
@@ -19,7 +19,7 @@ import (
"strconv"
)
-// convert interface to string.
+// GetString convert interface to string.
func GetString(v interface{}) string {
switch result := v.(type) {
case string:
@@ -34,7 +34,7 @@ func GetString(v interface{}) string {
return ""
}
-// convert interface to int.
+// GetInt convert interface to int.
func GetInt(v interface{}) int {
switch result := v.(type) {
case int:
@@ -52,7 +52,7 @@ func GetInt(v interface{}) int {
return 0
}
-// convert interface to int64.
+// GetInt64 convert interface to int64.
func GetInt64(v interface{}) int64 {
switch result := v.(type) {
case int:
@@ -71,7 +71,7 @@ func GetInt64(v interface{}) int64 {
return 0
}
-// convert interface to float64.
+// GetFloat64 convert interface to float64.
func GetFloat64(v interface{}) float64 {
switch result := v.(type) {
case float64:
@@ -85,7 +85,7 @@ func GetFloat64(v interface{}) float64 {
return 0
}
-// convert interface to bool.
+// GetBool convert interface to bool.
func GetBool(v interface{}) bool {
switch result := v.(type) {
case bool:
@@ -98,15 +98,3 @@ func GetBool(v interface{}) bool {
}
return false
}
-
-// convert interface to byte slice.
-func getByteArray(v interface{}) []byte {
- switch result := v.(type) {
- case []byte:
- return result
- case string:
- return []byte(result)
- default:
- return nil
- }
-}
diff --git a/cache/conv_test.go b/cache/conv_test.go
index 267bb0c9..cf792fa6 100644
--- a/cache/conv_test.go
+++ b/cache/conv_test.go
@@ -27,7 +27,7 @@ func TestGetString(t *testing.T) {
if "test2" != GetString(t2) {
t.Error("get string from byte array error")
}
- var t3 int = 1
+ var t3 = 1
if "1" != GetString(t3) {
t.Error("get string from int error")
}
@@ -35,7 +35,7 @@ func TestGetString(t *testing.T) {
if "1" != GetString(t4) {
t.Error("get string from int64 error")
}
- var t5 float64 = 1.1
+ var t5 = 1.1
if "1.1" != GetString(t5) {
t.Error("get string from float64 error")
}
@@ -46,7 +46,7 @@ func TestGetString(t *testing.T) {
}
func TestGetInt(t *testing.T) {
- var t1 int = 1
+ var t1 = 1
if 1 != GetInt(t1) {
t.Error("get int from int error")
}
@@ -69,7 +69,7 @@ func TestGetInt(t *testing.T) {
func TestGetInt64(t *testing.T) {
var i int64 = 1
- var t1 int = 1
+ var t1 = 1
if i != GetInt64(t1) {
t.Error("get int64 from int error")
}
@@ -91,12 +91,12 @@ func TestGetInt64(t *testing.T) {
}
func TestGetFloat64(t *testing.T) {
- var f float64 = 1.11
+ var f = 1.11
var t1 float32 = 1.11
if f != GetFloat64(t1) {
t.Error("get float64 from float32 error")
}
- var t2 float64 = 1.11
+ var t2 = 1.11
if f != GetFloat64(t2) {
t.Error("get float64 from float64 error")
}
@@ -106,7 +106,7 @@ func TestGetFloat64(t *testing.T) {
}
var f2 float64 = 1
- var t4 int = 1
+ var t4 = 1
if f2 != GetFloat64(t4) {
t.Error("get float64 from int error")
}
@@ -130,21 +130,6 @@ func TestGetBool(t *testing.T) {
}
}
-func TestGetByteArray(t *testing.T) {
- var b = []byte("test")
- var t1 = []byte("test")
- if !byteArrayEquals(b, getByteArray(t1)) {
- t.Error("get byte array from byte array error")
- }
- var t2 = "test"
- if !byteArrayEquals(b, getByteArray(t2)) {
- t.Error("get byte array from string error")
- }
- if nil != getByteArray(nil) {
- t.Error("get byte array from nil error")
- }
-}
-
func byteArrayEquals(a []byte, b []byte) bool {
if len(a) != len(b) {
return false
diff --git a/cache/file.go b/cache/file.go
index 65f114f3..3a7aa8b0 100644
--- a/cache/file.go
+++ b/cache/file.go
@@ -29,23 +29,20 @@ import (
"time"
)
-func init() {
- Register("file", NewFileCache())
-}
-
// FileCacheItem is basic unit of file cache adapter.
// it contains data and expire time.
type FileCacheItem struct {
Data interface{}
- Lastaccess int64
- Expired int64
+ Lastaccess time.Time
+ Expired time.Time
}
+// FileCache Config
var (
- FileCachePath string = "cache" // cache directory
- FileCacheFileSuffix string = ".bin" // cache file suffix
- FileCacheDirectoryLevel int = 2 // cache file deep level if auto generated cache files.
- FileCacheEmbedExpiry int64 = 0 // cache expire time, default is no expire forever.
+ FileCachePath = "cache" // cache directory
+ FileCacheFileSuffix = ".bin" // cache file suffix
+ FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files.
+ FileCacheEmbedExpiry time.Duration = 0 // cache expire time, default is no expire forever.
)
// FileCache is cache adapter for file storage.
@@ -56,14 +53,14 @@ type FileCache struct {
EmbedExpiry int
}
-// Create new file cache with no config.
+// NewFileCache Create new file cache with no config.
// the level and expiry need set in method StartAndGC as config string.
-func NewFileCache() *FileCache {
+func NewFileCache() Cache {
// return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix}
return &FileCache{}
}
-// Start and begin gc for file cache.
+// StartAndGC will start and begin gc for file cache.
// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}
func (fc *FileCache) StartAndGC(config string) error {
@@ -79,7 +76,7 @@ func (fc *FileCache) StartAndGC(config string) error {
cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel)
}
if _, ok := cfg["EmbedExpiry"]; !ok {
- cfg["EmbedExpiry"] = strconv.FormatInt(FileCacheEmbedExpiry, 10)
+ cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10)
}
fc.CachePath = cfg["CachePath"]
fc.FileSuffix = cfg["FileSuffix"]
@@ -120,13 +117,13 @@ func (fc *FileCache) getCacheFileName(key string) string {
// Get value from file cache.
// if non-exist or expired, return empty string.
func (fc *FileCache) Get(key string) interface{} {
- fileData, err := File_get_contents(fc.getCacheFileName(key))
+ fileData, err := FileGetContents(fc.getCacheFileName(key))
if err != nil {
return ""
}
var to FileCacheItem
- Gob_decode(fileData, &to)
- if to.Expired < time.Now().Unix() {
+ GobDecode(fileData, &to)
+ if to.Expired.Before(time.Now()) {
return ""
}
return to.Data
@@ -145,21 +142,21 @@ func (fc *FileCache) GetMulti(keys []string) []interface{} {
// Put value into file cache.
// timeout means how long to keep this file, unit of ms.
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
-func (fc *FileCache) Put(key string, val interface{}, timeout int64) error {
+func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error {
gob.Register(val)
item := FileCacheItem{Data: val}
if timeout == FileCacheEmbedExpiry {
- item.Expired = time.Now().Unix() + (86400 * 365 * 10) // ten years
+ item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years
} else {
- item.Expired = time.Now().Unix() + timeout
+ item.Expired = time.Now().Add(timeout)
}
- item.Lastaccess = time.Now().Unix()
- data, err := Gob_encode(item)
+ item.Lastaccess = time.Now()
+ data, err := GobEncode(item)
if err != nil {
return err
}
- return File_put_contents(fc.getCacheFileName(key), data)
+ return FilePutContents(fc.getCacheFileName(key), data)
}
// Delete file cache value.
@@ -171,7 +168,7 @@ func (fc *FileCache) Delete(key string) error {
return nil
}
-// Increase cached int value.
+// Incr will increase cached int value.
// fc value is saving forever unless Delete.
func (fc *FileCache) Incr(key string) error {
data := fc.Get(key)
@@ -185,7 +182,7 @@ func (fc *FileCache) Incr(key string) error {
return nil
}
-// Decrease cached int value.
+// Decr will decrease cached int value.
func (fc *FileCache) Decr(key string) error {
data := fc.Get(key)
var decr int
@@ -198,13 +195,13 @@ func (fc *FileCache) Decr(key string) error {
return nil
}
-// Check value is exist.
+// IsExist check value is exist.
func (fc *FileCache) IsExist(key string) bool {
ret, _ := exists(fc.getCacheFileName(key))
return ret
}
-// Clean cached files.
+// ClearAll will clean cached files.
// not implemented.
func (fc *FileCache) ClearAll() error {
return nil
@@ -222,9 +219,9 @@ func exists(path string) (bool, error) {
return false, err
}
-// Get bytes to file.
+// FileGetContents Get bytes to file.
// if non-exist, create this file.
-func File_get_contents(filename string) (data []byte, e error) {
+func FileGetContents(filename string) (data []byte, e error) {
f, e := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm)
if e != nil {
return
@@ -242,9 +239,9 @@ func File_get_contents(filename string) (data []byte, e error) {
return
}
-// Put bytes to file.
+// FilePutContents Put bytes to file.
// if non-exist, create this file.
-func File_put_contents(filename string, content []byte) error {
+func FilePutContents(filename string, content []byte) error {
fp, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm)
if err != nil {
return err
@@ -254,8 +251,8 @@ func File_put_contents(filename string, content []byte) error {
return err
}
-// Gob encodes file cache item.
-func Gob_encode(data interface{}) ([]byte, error) {
+// GobEncode Gob encodes file cache item.
+func GobEncode(data interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(data)
@@ -265,9 +262,13 @@ func Gob_encode(data interface{}) ([]byte, error) {
return buf.Bytes(), err
}
-// Gob decodes file cache item.
-func Gob_decode(data []byte, to *FileCacheItem) error {
+// GobDecode Gob decodes file cache item.
+func GobDecode(data []byte, to *FileCacheItem) error {
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
return dec.Decode(&to)
}
+
+func init() {
+ Register("file", NewFileCache)
+}
diff --git a/cache/memcache/memcache.go b/cache/memcache/memcache.go
index c6829054..15ea5d3e 100644
--- a/cache/memcache/memcache.go
+++ b/cache/memcache/memcache.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package memcahe for cache provider
+// Package memcache for cache provider
//
// depend on github.com/bradfitz/gomemcache/memcache
//
@@ -37,21 +37,22 @@ import (
"github.com/bradfitz/gomemcache/memcache"
"github.com/astaxie/beego/cache"
+ "time"
)
-// Memcache adapter.
-type MemcacheCache struct {
+// Cache Memcache adapter.
+type Cache struct {
conn *memcache.Client
conninfo []string
}
-// create new memcache adapter.
-func NewMemCache() *MemcacheCache {
- return &MemcacheCache{}
+// NewMemCache create new memcache adapter.
+func NewMemCache() cache.Cache {
+ return &Cache{}
}
-// get value from memcache.
-func (rc *MemcacheCache) Get(key string) interface{} {
+// Get get value from memcache.
+func (rc *Cache) Get(key string) interface{} {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -63,8 +64,8 @@ func (rc *MemcacheCache) Get(key string) interface{} {
return nil
}
-// get value from memcache.
-func (rc *MemcacheCache) GetMulti(keys []string) []interface{} {
+// GetMulti get value from memcache.
+func (rc *Cache) GetMulti(keys []string) []interface{} {
size := len(keys)
var rv []interface{}
if rc.conn == nil {
@@ -81,16 +82,15 @@ func (rc *MemcacheCache) GetMulti(keys []string) []interface{} {
rv = append(rv, string(v.Value))
}
return rv
- } else {
- for i := 0; i < size; i++ {
- rv = append(rv, err)
- }
- return rv
}
+ for i := 0; i < size; i++ {
+ rv = append(rv, err)
+ }
+ return rv
}
-// put value to memcache. only support string.
-func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
+// Put put value to memcache. only support string.
+func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -100,12 +100,12 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
if !ok {
return errors.New("val must string")
}
- item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout)}
+ item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout/time.Second)}
return rc.conn.Set(&item)
}
-// delete value in memcache.
-func (rc *MemcacheCache) Delete(key string) error {
+// Delete delete value in memcache.
+func (rc *Cache) Delete(key string) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -114,8 +114,8 @@ func (rc *MemcacheCache) Delete(key string) error {
return rc.conn.Delete(key)
}
-// increase counter.
-func (rc *MemcacheCache) Incr(key string) error {
+// Incr increase counter.
+func (rc *Cache) Incr(key string) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -125,8 +125,8 @@ func (rc *MemcacheCache) Incr(key string) error {
return err
}
-// decrease counter.
-func (rc *MemcacheCache) Decr(key string) error {
+// Decr decrease counter.
+func (rc *Cache) Decr(key string) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -136,8 +136,8 @@ func (rc *MemcacheCache) Decr(key string) error {
return err
}
-// check value exists in memcache.
-func (rc *MemcacheCache) IsExist(key string) bool {
+// IsExist check value exists in memcache.
+func (rc *Cache) IsExist(key string) bool {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return false
@@ -150,8 +150,8 @@ func (rc *MemcacheCache) IsExist(key string) bool {
return true
}
-// clear all cached in memcache.
-func (rc *MemcacheCache) ClearAll() error {
+// ClearAll clear all cached in memcache.
+func (rc *Cache) ClearAll() error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -160,10 +160,10 @@ func (rc *MemcacheCache) ClearAll() error {
return rc.conn.FlushAll()
}
-// start memcache adapter.
+// StartAndGC start memcache adapter.
// config string is like {"conn":"connection info"}.
// if connecting error, return.
-func (rc *MemcacheCache) StartAndGC(config string) error {
+func (rc *Cache) StartAndGC(config string) error {
var cf map[string]string
json.Unmarshal([]byte(config), &cf)
if _, ok := cf["conn"]; !ok {
@@ -179,11 +179,11 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
}
// connect to memcache and keep the connection.
-func (rc *MemcacheCache) connectInit() error {
+func (rc *Cache) connectInit() error {
rc.conn = memcache.New(rc.conninfo...)
return nil
}
func init() {
- cache.Register("memcache", NewMemCache())
+ cache.Register("memcache", NewMemCache)
}
diff --git a/cache/memcache/memcache_test.go b/cache/memcache/memcache_test.go
index 0523ae85..19629059 100644
--- a/cache/memcache/memcache_test.go
+++ b/cache/memcache/memcache_test.go
@@ -23,12 +23,13 @@ import (
"time"
)
-func TestRedisCache(t *testing.T) {
+func TestMemcacheCache(t *testing.T) {
bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`)
if err != nil {
t.Error("init err")
}
- if err = bm.Put("astaxie", "1", 10); err != nil {
+ timeoutDuration := 10 * time.Second
+ if err = bm.Put("astaxie", "1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@@ -40,7 +41,7 @@ func TestRedisCache(t *testing.T) {
if bm.IsExist("astaxie") {
t.Error("check err")
}
- if err = bm.Put("astaxie", "1", 10); err != nil {
+ if err = bm.Put("astaxie", "1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
@@ -69,7 +70,7 @@ func TestRedisCache(t *testing.T) {
}
//test string
- if err = bm.Put("astaxie", "author", 10); err != nil {
+ if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@@ -81,7 +82,7 @@ func TestRedisCache(t *testing.T) {
}
//test GetMulti
- if err = bm.Put("astaxie1", "author1", 10); err != nil {
+ if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie1") {
diff --git a/cache/memory.go b/cache/memory.go
index b6657048..d928afdb 100644
--- a/cache/memory.go
+++ b/cache/memory.go
@@ -17,34 +17,41 @@ package cache
import (
"encoding/json"
"errors"
- "fmt"
"sync"
"time"
)
var (
- // clock time of recycling the expired cache items in memory.
- DefaultEvery int = 60 // 1 minute
+ // DefaultEvery means the clock time of recycling the expired cache items in memory.
+ DefaultEvery = 60 // 1 minute
)
-// Memory cache item.
+// MemoryItem store memory cache item.
type MemoryItem struct {
- val interface{}
- Lastaccess time.Time
- expired int64
+ val interface{}
+ createdTime time.Time
+ lifespan time.Duration
}
-// Memory cache adapter.
+func (mi *MemoryItem) isExpire() bool {
+ // 0 means forever
+ if mi.lifespan == 0 {
+ return false
+ }
+ return time.Now().Sub(mi.createdTime) > mi.lifespan
+}
+
+// MemoryCache is Memory cache adapter.
// it contains a RW locker for safe map storage.
type MemoryCache struct {
- lock sync.RWMutex
+ sync.RWMutex
dur time.Duration
items map[string]*MemoryItem
Every int // run an expiration check Every clock time
}
// NewMemoryCache returns a new MemoryCache.
-func NewMemoryCache() *MemoryCache {
+func NewMemoryCache() Cache {
cache := MemoryCache{items: make(map[string]*MemoryItem)}
return &cache
}
@@ -52,11 +59,10 @@ func NewMemoryCache() *MemoryCache {
// Get cache from memory.
// if non-existed or expired, return nil.
func (bc *MemoryCache) Get(name string) interface{} {
- bc.lock.RLock()
- defer bc.lock.RUnlock()
+ bc.RLock()
+ defer bc.RUnlock()
if itm, ok := bc.items[name]; ok {
- if (time.Now().Unix() - itm.Lastaccess.Unix()) > itm.expired {
- go bc.Delete(name)
+ if itm.isExpire() {
return nil
}
return itm.val
@@ -75,22 +81,22 @@ func (bc *MemoryCache) GetMulti(names []string) []interface{} {
}
// Put cache to memory.
-// if expired is 0, it will be cleaned by next gc operation ( default gc clock is 1 minute).
-func (bc *MemoryCache) Put(name string, value interface{}, expired int64) error {
- bc.lock.Lock()
- defer bc.lock.Unlock()
+// if lifespan is 0, it will be forever till restart.
+func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error {
+ bc.Lock()
+ defer bc.Unlock()
bc.items[name] = &MemoryItem{
val: value,
- Lastaccess: time.Now(),
- expired: expired,
+ createdTime: time.Now(),
+ lifespan: lifespan,
}
return nil
}
-/// Delete cache in memory.
+// Delete cache in memory.
func (bc *MemoryCache) Delete(name string) error {
- bc.lock.Lock()
- defer bc.lock.Unlock()
+ bc.Lock()
+ defer bc.Unlock()
if _, ok := bc.items[name]; !ok {
return errors.New("key not exist")
}
@@ -101,11 +107,11 @@ func (bc *MemoryCache) Delete(name string) error {
return nil
}
-// Increase cache counter in memory.
-// it supports int,int64,int32,uint,uint64,uint32.
+// Incr increase cache counter in memory.
+// it supports int,int32,int64,uint,uint32,uint64.
func (bc *MemoryCache) Incr(key string) error {
- bc.lock.RLock()
- defer bc.lock.RUnlock()
+ bc.RLock()
+ defer bc.RUnlock()
itm, ok := bc.items[key]
if !ok {
return errors.New("key not exist")
@@ -113,10 +119,10 @@ func (bc *MemoryCache) Incr(key string) error {
switch itm.val.(type) {
case int:
itm.val = itm.val.(int) + 1
- case int64:
- itm.val = itm.val.(int64) + 1
case int32:
itm.val = itm.val.(int32) + 1
+ case int64:
+ itm.val = itm.val.(int64) + 1
case uint:
itm.val = itm.val.(uint) + 1
case uint32:
@@ -124,15 +130,15 @@ func (bc *MemoryCache) Incr(key string) error {
case uint64:
itm.val = itm.val.(uint64) + 1
default:
- return errors.New("item val is not int int64 int32")
+ return errors.New("item val is not (u)int (u)int32 (u)int64")
}
return nil
}
-// Decrease counter in memory.
+// Decr decrease counter in memory.
func (bc *MemoryCache) Decr(key string) error {
- bc.lock.RLock()
- defer bc.lock.RUnlock()
+ bc.RLock()
+ defer bc.RUnlock()
itm, ok := bc.items[key]
if !ok {
return errors.New("key not exist")
@@ -168,23 +174,25 @@ func (bc *MemoryCache) Decr(key string) error {
return nil
}
-// check cache exist in memory.
+// IsExist check cache exist in memory.
func (bc *MemoryCache) IsExist(name string) bool {
- bc.lock.RLock()
- defer bc.lock.RUnlock()
- _, ok := bc.items[name]
- return ok
+ bc.RLock()
+ defer bc.RUnlock()
+ if v, ok := bc.items[name]; ok {
+ return !v.isExpire()
+ }
+ return false
}
-// delete all cache in memory.
+// ClearAll will delete all cache in memory.
func (bc *MemoryCache) ClearAll() error {
- bc.lock.Lock()
- defer bc.lock.Unlock()
+ bc.Lock()
+ defer bc.Unlock()
bc.items = make(map[string]*MemoryItem)
return nil
}
-// start memory cache. it will check expiration in every clock time.
+// StartAndGC start memory cache. it will check expiration in every clock time.
func (bc *MemoryCache) StartAndGC(config string) error {
var cf map[string]int
json.Unmarshal([]byte(config), &cf)
@@ -192,10 +200,7 @@ func (bc *MemoryCache) StartAndGC(config string) error {
cf = make(map[string]int)
cf["interval"] = DefaultEvery
}
- dur, err := time.ParseDuration(fmt.Sprintf("%ds", cf["interval"]))
- if err != nil {
- return err
- }
+ dur := time.Duration(cf["interval"]) * time.Second
bc.Every = cf["interval"]
bc.dur = dur
go bc.vaccuum()
@@ -213,20 +218,21 @@ func (bc *MemoryCache) vaccuum() {
return
}
for name := range bc.items {
- bc.item_expired(name)
+ bc.itemExpired(name)
}
}
}
-// item_expired returns true if an item is expired.
-func (bc *MemoryCache) item_expired(name string) bool {
- bc.lock.Lock()
- defer bc.lock.Unlock()
+// itemExpired returns true if an item is expired.
+func (bc *MemoryCache) itemExpired(name string) bool {
+ bc.Lock()
+ defer bc.Unlock()
+
itm, ok := bc.items[name]
if !ok {
return true
}
- if time.Now().Unix()-itm.Lastaccess.Unix() >= itm.expired {
+ if itm.isExpire() {
delete(bc.items, name)
return true
}
@@ -234,5 +240,5 @@ func (bc *MemoryCache) item_expired(name string) bool {
}
func init() {
- Register("memory", NewMemoryCache())
+ Register("memory", NewMemoryCache)
}
diff --git a/cache/redis/redis.go b/cache/redis/redis.go
index d14b1ada..781e3836 100644
--- a/cache/redis/redis.go
+++ b/cache/redis/redis.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package redis for cache provider
+// Package redis for cache provider
//
// depend on github.com/garyburd/redigo/redis
//
@@ -41,12 +41,12 @@ import (
)
var (
- // the collection name of redis for cache adapter.
- DefaultKey string = "beecacheRedis"
+ // DefaultKey the collection name of redis for cache adapter.
+ DefaultKey = "beecacheRedis"
)
-// Redis cache adapter.
-type RedisCache struct {
+// Cache is Redis cache adapter.
+type Cache struct {
p *redis.Pool // redis connection pool
conninfo string
dbNum int
@@ -54,13 +54,13 @@ type RedisCache struct {
password string
}
-// create new redis cache with default collection name.
-func NewRedisCache() *RedisCache {
- return &RedisCache{key: DefaultKey}
+// NewRedisCache create new redis cache with default collection name.
+func NewRedisCache() cache.Cache {
+ return &Cache{key: DefaultKey}
}
// actually do the redis cmds
-func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
+func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
c := rc.p.Get()
defer c.Close()
@@ -68,7 +68,7 @@ func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interfa
}
// Get cache from redis.
-func (rc *RedisCache) Get(key string) interface{} {
+func (rc *Cache) Get(key string) interface{} {
if v, err := rc.do("GET", key); err == nil {
return v
}
@@ -76,7 +76,7 @@ func (rc *RedisCache) Get(key string) interface{} {
}
// GetMulti get cache from redis.
-func (rc *RedisCache) GetMulti(keys []string) []interface{} {
+func (rc *Cache) GetMulti(keys []string) []interface{} {
size := len(keys)
var rv []interface{}
c := rc.p.Get()
@@ -108,10 +108,10 @@ ERROR:
return rv
}
-// put cache to redis.
-func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
+// Put put cache to redis.
+func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
var err error
- if _, err = rc.do("SETEX", key, timeout, val); err != nil {
+ if _, err = rc.do("SETEX", key, int64(timeout/time.Second), val); err != nil {
return err
}
@@ -121,8 +121,8 @@ func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
return err
}
-// delete cache in redis.
-func (rc *RedisCache) Delete(key string) error {
+// Delete delete cache in redis.
+func (rc *Cache) Delete(key string) error {
var err error
if _, err = rc.do("DEL", key); err != nil {
return err
@@ -131,8 +131,8 @@ func (rc *RedisCache) Delete(key string) error {
return err
}
-// check cache's existence in redis.
-func (rc *RedisCache) IsExist(key string) bool {
+// IsExist check cache's existence in redis.
+func (rc *Cache) IsExist(key string) bool {
v, err := redis.Bool(rc.do("EXISTS", key))
if err != nil {
return false
@@ -145,20 +145,20 @@ func (rc *RedisCache) IsExist(key string) bool {
return v
}
-// increase counter in redis.
-func (rc *RedisCache) Incr(key string) error {
+// Incr increase counter in redis.
+func (rc *Cache) Incr(key string) error {
_, err := redis.Bool(rc.do("INCRBY", key, 1))
return err
}
-// decrease counter in redis.
-func (rc *RedisCache) Decr(key string) error {
+// Decr decrease counter in redis.
+func (rc *Cache) Decr(key string) error {
_, err := redis.Bool(rc.do("INCRBY", key, -1))
return err
}
-// clean all cache in redis. delete this redis collection.
-func (rc *RedisCache) ClearAll() error {
+// ClearAll clean all cache in redis. delete this redis collection.
+func (rc *Cache) ClearAll() error {
cachedKeys, err := redis.Strings(rc.do("HKEYS", rc.key))
if err != nil {
return err
@@ -172,11 +172,11 @@ func (rc *RedisCache) ClearAll() error {
return err
}
-// start redis cache adapter.
+// StartAndGC start redis cache adapter.
// config is like {"key":"collection key","conn":"connection info","dbNum":"0"}
// the cache item in redis are stored forever,
// so no gc operation.
-func (rc *RedisCache) StartAndGC(config string) error {
+func (rc *Cache) StartAndGC(config string) error {
var cf map[string]string
json.Unmarshal([]byte(config), &cf)
@@ -206,7 +206,7 @@ func (rc *RedisCache) StartAndGC(config string) error {
}
// connect to redis.
-func (rc *RedisCache) connectInit() {
+func (rc *Cache) connectInit() {
dialFunc := func() (c redis.Conn, err error) {
c, err = redis.Dial("tcp", rc.conninfo)
if err != nil {
@@ -236,5 +236,5 @@ func (rc *RedisCache) connectInit() {
}
func init() {
- cache.Register("redis", NewRedisCache())
+ cache.Register("redis", NewRedisCache)
}
diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go
index 1f74fd27..47c5acc6 100644
--- a/cache/redis/redis_test.go
+++ b/cache/redis/redis_test.go
@@ -28,19 +28,20 @@ func TestRedisCache(t *testing.T) {
if err != nil {
t.Error("init err")
}
- if err = bm.Put("astaxie", 1, 10); err != nil {
+ timeoutDuration := 10 * time.Second
+ if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
- time.Sleep(10 * time.Second)
+ time.Sleep(11 * time.Second)
if bm.IsExist("astaxie") {
t.Error("check err")
}
- if err = bm.Put("astaxie", 1, 10); err != nil {
+ if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
@@ -69,7 +70,7 @@ func TestRedisCache(t *testing.T) {
}
//test string
- if err = bm.Put("astaxie", "author", 10); err != nil {
+ if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@@ -81,7 +82,7 @@ func TestRedisCache(t *testing.T) {
}
//test GetMulti
- if err = bm.Put("astaxie1", "author1", 10); err != nil {
+ if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie1") {
diff --git a/config.go b/config.go
index 09d0df24..cffb7c2b 100644
--- a/config.go
+++ b/config.go
@@ -15,78 +15,264 @@
package beego
import (
- "fmt"
"html/template"
"os"
"path/filepath"
- "runtime"
"strings"
"github.com/astaxie/beego/config"
- "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
)
-var (
- BeeApp *App // beego application
- AppName string
- AppPath string
- workPath string
- AppConfigPath string
+type BeegoConfig struct {
+ AppName string //Application name
+ RunMode string //Running Mode: dev | prod
+ RouterCaseSensitive bool
+ ServerName string
+ RecoverPanic bool
+ CopyRequestBody bool
+ EnableGzip bool
+ MaxMemory int64
+ EnableErrorsShow bool
+ Listen Listen
+ WebConfig WebConfig
+ Log LogConfig
+}
+
+type Listen struct {
+ Graceful bool // Graceful means use graceful module to start the server
+ ServerTimeOut int64
+ ListenTCP4 bool
+ EnableHTTP bool
+ HTTPAddr string
+ HTTPPort int
+ EnableHTTPS bool
+ HTTPSAddr string
+ HTTPSPort int
+ HTTPSCertFile string
+ HTTPSKeyFile string
+ EnableAdmin bool
+ AdminAddr string
+ AdminPort int
+ EnableFcgi bool
+ EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O
+}
+
+type WebConfig struct {
+ AutoRender bool
+ EnableDocs bool
+ FlashName string
+ FlashSeparator string
+ DirectoryIndex bool
StaticDir map[string]string
- TemplateCache map[string]*template.Template // template caching map
- StaticExtensionsToGzip []string // files with should be compressed with gzip (.js,.css,etc)
- EnableHttpListen bool
- HttpAddr string
- HttpPort int
- ListenTCP4 bool
- EnableHttpTLS bool
- HttpsPort int
- HttpCertFile string
- HttpKeyFile string
- RecoverPanic bool // flag of auto recover panic
- AutoRender bool // flag of render template automatically
- ViewsPath string
- AppConfig *beegoAppConfig
- RunMode string // run mode, "dev" or "prod"
- GlobalSessions *session.Manager // global session mananger
- SessionOn bool // flag of starting session auto. default is false.
- SessionProvider string // default session provider, memory, mysql , redis ,etc.
- SessionName string // the cookie name when saving session id into cookie.
- SessionGCMaxLifetime int64 // session gc time for auto cleaning expired session.
- SessionSavePath string // if use mysql/redis/file provider, define save path to connection info.
- SessionCookieLifeTime int // the life time of session id in cookie.
- SessionAutoSetCookie bool // auto setcookie
- SessionDomain string // the cookie domain default is empty
- UseFcgi bool
- UseStdIo bool
- MaxMemory int64
- EnableGzip bool // flag of enable gzip
- DirectoryIndex bool // flag of display directory index. default is false.
- HttpServerTimeOut int64
- ErrorsShow bool // flag of show errors in page. if true, show error and trace info in page rendered with error template.
- XSRFKEY string // xsrf hash salt string.
- EnableXSRF bool // flag of enable xsrf.
- XSRFExpire int // the expiry of xsrf value.
- CopyRequestBody bool // flag of copy raw request body in context.
+ StaticExtensionsToGzip []string
TemplateLeft string
TemplateRight string
- BeegoServerName string // beego server name exported in response header.
- EnableAdmin bool // flag of enable admin module to log every request info.
- AdminHttpAddr string // http server configurations for admin module.
- AdminHttpPort int
- FlashName string // name of the flash variable found in response header and cookie
- FlashSeperator string // used to seperate flash key:value
- AppConfigProvider string // config provider
- EnableDocs bool // enable generate docs & server docs API Swagger
- RouterCaseSensitive bool // router case sensitive default is true
- AccessLogs bool // print access logs, default is false
- Graceful bool // use graceful start the server
+ ViewsPath string
+ EnableXSRF bool
+ XSRFKey string
+ XSRFExpire int
+ Session SessionConfig
+}
+
+type SessionConfig struct {
+ SessionOn bool
+ SessionProvider string
+ SessionName string
+ SessionGCMaxLifetime int64
+ SessionProviderConfig string
+ SessionCookieLifeTime int
+ SessionAutoSetCookie bool
+ SessionDomain string
+}
+
+type LogConfig struct {
+ AccessLogs bool
+ FileLineNum bool
+ Outputs map[string]string // Store Adaptor : config
+}
+
+var (
+ // BConfig is the default config for Application
+ BConfig *BeegoConfig
+ // AppConfig is the instance of Config, store the config information from file
+ AppConfig *beegoAppConfig
+ // AppConfigPath is the path to the config files
+ AppConfigPath string
+ // AppConfigProvider is the provider for the config, default is ini
+ AppConfigProvider = "ini"
+ // TemplateCache stores template caching
+ TemplateCache map[string]*template.Template
+ // GlobalSessions is the instance for the session manager
+ GlobalSessions *session.Manager
)
+func init() {
+ BConfig = &BeegoConfig{
+ AppName: "beego",
+ RunMode: DEV,
+ RouterCaseSensitive: true,
+ ServerName: "beegoServer:" + VERSION,
+ RecoverPanic: true,
+ CopyRequestBody: false,
+ EnableGzip: false,
+ MaxMemory: 1 << 26, //64MB
+ EnableErrorsShow: true,
+ Listen: Listen{
+ Graceful: false,
+ ServerTimeOut: 0,
+ ListenTCP4: false,
+ EnableHTTP: true,
+ HTTPAddr: "",
+ HTTPPort: 8080,
+ EnableHTTPS: false,
+ HTTPSAddr: "",
+ HTTPSPort: 10443,
+ HTTPSCertFile: "",
+ HTTPSKeyFile: "",
+ EnableAdmin: false,
+ AdminAddr: "",
+ AdminPort: 8088,
+ EnableFcgi: false,
+ EnableStdIo: false,
+ },
+ WebConfig: WebConfig{
+ AutoRender: true,
+ EnableDocs: false,
+ FlashName: "BEEGO_FLASH",
+ FlashSeparator: "BEEGOFLASH",
+ DirectoryIndex: false,
+ StaticDir: map[string]string{"/static": "static"},
+ StaticExtensionsToGzip: []string{".css", ".js"},
+ TemplateLeft: "{{",
+ TemplateRight: "}}",
+ ViewsPath: "views",
+ EnableXSRF: false,
+ XSRFKey: "beegoxsrf",
+ XSRFExpire: 0,
+ Session: SessionConfig{
+ SessionOn: false,
+ SessionProvider: "memory",
+ SessionName: "beegosessionID",
+ SessionGCMaxLifetime: 3600,
+ SessionProviderConfig: "",
+ SessionCookieLifeTime: 0, //set cookie default is the brower life
+ SessionAutoSetCookie: true,
+ SessionDomain: "",
+ },
+ },
+ Log: LogConfig{
+ AccessLogs: false,
+ FileLineNum: true,
+ Outputs: map[string]string{"console": ""},
+ },
+ }
+ ParseConfig()
+}
+
+// ParseConfig parsed default config file.
+// now only support ini, next will support json.
+func ParseConfig() (err error) {
+ if AppConfigPath == "" {
+ if utils.FileExists(filepath.Join("conf", "app.conf")) {
+ AppConfigPath = filepath.Join("conf", "app.conf")
+ } else {
+ AppConfig = &beegoAppConfig{config.NewFakeConfig()}
+ return
+ }
+ }
+ AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath)
+ if err != nil {
+ return err
+ }
+ // set the runmode first
+ if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
+ BConfig.RunMode = envRunMode
+ } else if runmode := AppConfig.String("RunMode"); runmode != "" {
+ BConfig.RunMode = runmode
+ }
+
+ BConfig.AppName = AppConfig.DefaultString("AppName", BConfig.AppName)
+ BConfig.RecoverPanic = AppConfig.DefaultBool("RecoverPanic", BConfig.RecoverPanic)
+ BConfig.RouterCaseSensitive = AppConfig.DefaultBool("RouterCaseSensitive", BConfig.RouterCaseSensitive)
+ BConfig.ServerName = AppConfig.DefaultString("ServerName", BConfig.ServerName)
+ BConfig.EnableGzip = AppConfig.DefaultBool("EnableGzip", BConfig.EnableGzip)
+ BConfig.EnableErrorsShow = AppConfig.DefaultBool("EnableErrorsShow", BConfig.EnableErrorsShow)
+ BConfig.CopyRequestBody = AppConfig.DefaultBool("CopyRequestBody", BConfig.CopyRequestBody)
+ BConfig.MaxMemory = AppConfig.DefaultInt64("MaxMemory", BConfig.MaxMemory)
+ BConfig.Listen.Graceful = AppConfig.DefaultBool("Graceful", BConfig.Listen.Graceful)
+ BConfig.Listen.HTTPAddr = AppConfig.String("HTTPAddr")
+ BConfig.Listen.HTTPPort = AppConfig.DefaultInt("HTTPPort", BConfig.Listen.HTTPPort)
+ BConfig.Listen.ListenTCP4 = AppConfig.DefaultBool("ListenTCP4", BConfig.Listen.ListenTCP4)
+ BConfig.Listen.EnableHTTP = AppConfig.DefaultBool("EnableHTTP", BConfig.Listen.EnableHTTP)
+ BConfig.Listen.EnableHTTPS = AppConfig.DefaultBool("EnableHTTPS", BConfig.Listen.EnableHTTPS)
+ BConfig.Listen.HTTPSAddr = AppConfig.DefaultString("HTTPSAddr", BConfig.Listen.HTTPSAddr)
+ BConfig.Listen.HTTPSPort = AppConfig.DefaultInt("HTTPSPort", BConfig.Listen.HTTPSPort)
+ BConfig.Listen.HTTPSCertFile = AppConfig.DefaultString("HTTPSCertFile", BConfig.Listen.HTTPSCertFile)
+ BConfig.Listen.HTTPSKeyFile = AppConfig.DefaultString("HTTPSKeyFile", BConfig.Listen.HTTPSKeyFile)
+ BConfig.Listen.EnableAdmin = AppConfig.DefaultBool("EnableAdmin", BConfig.Listen.EnableAdmin)
+ BConfig.Listen.AdminAddr = AppConfig.DefaultString("AdminAddr", BConfig.Listen.AdminAddr)
+ BConfig.Listen.AdminPort = AppConfig.DefaultInt("AdminPort", BConfig.Listen.AdminPort)
+ BConfig.Listen.EnableFcgi = AppConfig.DefaultBool("EnableFcgi", BConfig.Listen.EnableFcgi)
+ BConfig.Listen.EnableStdIo = AppConfig.DefaultBool("EnableStdIo", BConfig.Listen.EnableStdIo)
+ BConfig.Listen.ServerTimeOut = AppConfig.DefaultInt64("ServerTimeOut", BConfig.Listen.ServerTimeOut)
+ BConfig.WebConfig.AutoRender = AppConfig.DefaultBool("AutoRender", BConfig.WebConfig.AutoRender)
+ BConfig.WebConfig.ViewsPath = AppConfig.DefaultString("ViewsPath", BConfig.WebConfig.ViewsPath)
+ BConfig.WebConfig.DirectoryIndex = AppConfig.DefaultBool("DirectoryIndex", BConfig.WebConfig.DirectoryIndex)
+ BConfig.WebConfig.FlashName = AppConfig.DefaultString("FlashName", BConfig.WebConfig.FlashName)
+ BConfig.WebConfig.FlashSeparator = AppConfig.DefaultString("FlashSeparator", BConfig.WebConfig.FlashSeparator)
+ BConfig.WebConfig.EnableDocs = AppConfig.DefaultBool("EnableDocs", BConfig.WebConfig.EnableDocs)
+ BConfig.WebConfig.XSRFKey = AppConfig.DefaultString("XSRFKEY", BConfig.WebConfig.XSRFKey)
+ BConfig.WebConfig.EnableXSRF = AppConfig.DefaultBool("EnableXSRF", BConfig.WebConfig.EnableXSRF)
+ BConfig.WebConfig.XSRFExpire = AppConfig.DefaultInt("XSRFExpire", BConfig.WebConfig.XSRFExpire)
+ BConfig.WebConfig.TemplateLeft = AppConfig.DefaultString("TemplateLeft", BConfig.WebConfig.TemplateLeft)
+ BConfig.WebConfig.TemplateRight = AppConfig.DefaultString("TemplateRight", BConfig.WebConfig.TemplateRight)
+ BConfig.WebConfig.Session.SessionOn = AppConfig.DefaultBool("SessionOn", BConfig.WebConfig.Session.SessionOn)
+ BConfig.WebConfig.Session.SessionProvider = AppConfig.DefaultString("SessionProvider", BConfig.WebConfig.Session.SessionProvider)
+ BConfig.WebConfig.Session.SessionName = AppConfig.DefaultString("SessionName", BConfig.WebConfig.Session.SessionName)
+ BConfig.WebConfig.Session.SessionProviderConfig = AppConfig.DefaultString("SessionProviderConfig", BConfig.WebConfig.Session.SessionProviderConfig)
+ BConfig.WebConfig.Session.SessionGCMaxLifetime = AppConfig.DefaultInt64("SessionGCMaxLifetime", BConfig.WebConfig.Session.SessionGCMaxLifetime)
+ BConfig.WebConfig.Session.SessionCookieLifeTime = AppConfig.DefaultInt("SessionCookieLifeTime", BConfig.WebConfig.Session.SessionCookieLifeTime)
+ BConfig.WebConfig.Session.SessionAutoSetCookie = AppConfig.DefaultBool("SessionAutoSetCookie", BConfig.WebConfig.Session.SessionAutoSetCookie)
+ BConfig.WebConfig.Session.SessionDomain = AppConfig.DefaultString("SessionDomain", BConfig.WebConfig.Session.SessionDomain)
+
+ if sd := AppConfig.String("StaticDir"); sd != "" {
+ for k := range BConfig.WebConfig.StaticDir {
+ delete(BConfig.WebConfig.StaticDir, k)
+ }
+ sds := strings.Fields(sd)
+ for _, v := range sds {
+ if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
+ BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
+ } else {
+ BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
+ }
+ }
+ }
+
+ if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" {
+ extensions := strings.Split(sgz, ",")
+ fileExts := []string{}
+ for _, ext := range extensions {
+ ext = strings.TrimSpace(ext)
+ if ext == "" {
+ continue
+ }
+ if !strings.HasPrefix(ext, ".") {
+ ext = "." + ext
+ }
+ fileExts = append(fileExts, ext)
+ }
+ if len(fileExts) > 0 {
+ BConfig.WebConfig.StaticExtensionsToGzip = fileExts
+ }
+ }
+ return nil
+}
+
type beegoAppConfig struct {
- innerConfig config.ConfigContainer
+ innerConfig config.Configer
}
func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, error) {
@@ -94,109 +280,95 @@ func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, err
if err != nil {
return nil, err
}
- rac := &beegoAppConfig{ac}
- return rac, nil
+ return &beegoAppConfig{ac}, nil
}
func (b *beegoAppConfig) Set(key, val string) error {
- err := b.innerConfig.Set(RunMode+"::"+key, val)
- if err == nil {
+ if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil {
return err
}
return b.innerConfig.Set(key, val)
}
func (b *beegoAppConfig) String(key string) string {
- v := b.innerConfig.String(RunMode + "::" + key)
- if v == "" {
- return b.innerConfig.String(key)
+ if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" {
+ return v
}
- return v
+ return b.innerConfig.String(key)
}
func (b *beegoAppConfig) Strings(key string) []string {
- v := b.innerConfig.Strings(RunMode + "::" + key)
- if v[0] == "" {
- return b.innerConfig.Strings(key)
+ if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); v[0] != "" {
+ return v
}
- return v
+ return b.innerConfig.Strings(key)
}
func (b *beegoAppConfig) Int(key string) (int, error) {
- v, err := b.innerConfig.Int(RunMode + "::" + key)
- if err != nil {
- return b.innerConfig.Int(key)
+ if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil {
+ return v, nil
}
- return v, nil
+ return b.innerConfig.Int(key)
}
func (b *beegoAppConfig) Int64(key string) (int64, error) {
- v, err := b.innerConfig.Int64(RunMode + "::" + key)
- if err != nil {
- return b.innerConfig.Int64(key)
+ if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil {
+ return v, nil
}
- return v, nil
+ return b.innerConfig.Int64(key)
}
func (b *beegoAppConfig) Bool(key string) (bool, error) {
- v, err := b.innerConfig.Bool(RunMode + "::" + key)
- if err != nil {
- return b.innerConfig.Bool(key)
+ if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil {
+ return v, nil
}
- return v, nil
+ return b.innerConfig.Bool(key)
}
func (b *beegoAppConfig) Float(key string) (float64, error) {
- v, err := b.innerConfig.Float(RunMode + "::" + key)
- if err != nil {
- return b.innerConfig.Float(key)
+ if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil {
+ return v, nil
}
- return v, nil
+ return b.innerConfig.Float(key)
}
func (b *beegoAppConfig) DefaultString(key string, defaultval string) string {
- v := b.String(key)
- if v != "" {
+ if v := b.String(key); v != "" {
return v
}
return defaultval
}
func (b *beegoAppConfig) DefaultStrings(key string, defaultval []string) []string {
- v := b.Strings(key)
- if len(v) != 0 {
+ if v := b.Strings(key); len(v) != 0 {
return v
}
return defaultval
}
func (b *beegoAppConfig) DefaultInt(key string, defaultval int) int {
- v, err := b.Int(key)
- if err == nil {
+ if v, err := b.Int(key); err == nil {
return v
}
return defaultval
}
func (b *beegoAppConfig) DefaultInt64(key string, defaultval int64) int64 {
- v, err := b.Int64(key)
- if err == nil {
+ if v, err := b.Int64(key); err == nil {
return v
}
return defaultval
}
func (b *beegoAppConfig) DefaultBool(key string, defaultval bool) bool {
- v, err := b.Bool(key)
- if err == nil {
+ if v, err := b.Bool(key); err == nil {
return v
}
return defaultval
}
func (b *beegoAppConfig) DefaultFloat(key string, defaultval float64) float64 {
- v, err := b.Float(key)
- if err == nil {
+ if v, err := b.Float(key); err == nil {
return v
}
return defaultval
@@ -213,305 +385,3 @@ func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) {
func (b *beegoAppConfig) SaveConfigFile(filename string) error {
return b.innerConfig.SaveConfigFile(filename)
}
-
-func init() {
- // create beego application
- BeeApp = NewApp()
-
- workPath, _ = os.Getwd()
- workPath, _ = filepath.Abs(workPath)
- // initialize default configurations
- AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
-
- AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
-
- if workPath != AppPath {
- if utils.FileExists(AppConfigPath) {
- os.Chdir(AppPath)
- } else {
- AppConfigPath = filepath.Join(workPath, "conf", "app.conf")
- }
- }
-
- AppConfigProvider = "ini"
-
- StaticDir = make(map[string]string)
- StaticDir["/static"] = "static"
-
- StaticExtensionsToGzip = []string{".css", ".js"}
-
- TemplateCache = make(map[string]*template.Template)
-
- // set this to 0.0.0.0 to make this app available to externally
- EnableHttpListen = true //default enable http Listen
-
- HttpAddr = ""
- HttpPort = 8080
-
- HttpsPort = 10443
-
- AppName = "beego"
-
- RunMode = "dev" //default runmod
-
- AutoRender = true
-
- RecoverPanic = true
-
- ViewsPath = "views"
-
- SessionOn = false
- SessionProvider = "memory"
- SessionName = "beegosessionID"
- SessionGCMaxLifetime = 3600
- SessionSavePath = ""
- SessionCookieLifeTime = 0 //set cookie default is the brower life
- SessionAutoSetCookie = true
-
- UseFcgi = false
- UseStdIo = false
-
- MaxMemory = 1 << 26 //64MB
-
- EnableGzip = false
-
- HttpServerTimeOut = 0
-
- ErrorsShow = true
-
- XSRFKEY = "beegoxsrf"
- XSRFExpire = 0
-
- TemplateLeft = "{{"
- TemplateRight = "}}"
-
- BeegoServerName = "beegoServer:" + VERSION
-
- EnableAdmin = false
- AdminHttpAddr = "127.0.0.1"
- AdminHttpPort = 8088
-
- FlashName = "BEEGO_FLASH"
- FlashSeperator = "BEEGOFLASH"
-
- RouterCaseSensitive = true
-
- runtime.GOMAXPROCS(runtime.NumCPU())
-
- // init BeeLogger
- BeeLogger = logs.NewLogger(10000)
- err := BeeLogger.SetLogger("console", "")
- if err != nil {
- fmt.Println("init console log error:", err)
- }
- SetLogFuncCall(true)
-
- err = ParseConfig()
- if err != nil && os.IsNotExist(err) {
- // for init if doesn't have app.conf will not panic
- ac := config.NewFakeConfig()
- AppConfig = &beegoAppConfig{ac}
- Warning(err)
- }
-}
-
-// ParseConfig parsed default config file.
-// now only support ini, next will support json.
-func ParseConfig() (err error) {
- AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath)
- if err != nil {
- return err
- }
- envRunMode := os.Getenv("BEEGO_RUNMODE")
- // set the runmode first
- if envRunMode != "" {
- RunMode = envRunMode
- } else if runmode := AppConfig.String("RunMode"); runmode != "" {
- RunMode = runmode
- }
-
- HttpAddr = AppConfig.String("HttpAddr")
-
- if v, err := AppConfig.Int("HttpPort"); err == nil {
- HttpPort = v
- }
-
- if v, err := AppConfig.Bool("ListenTCP4"); err == nil {
- ListenTCP4 = v
- }
-
- if v, err := AppConfig.Bool("EnableHttpListen"); err == nil {
- EnableHttpListen = v
- }
-
- if maxmemory, err := AppConfig.Int64("MaxMemory"); err == nil {
- MaxMemory = maxmemory
- }
-
- if appname := AppConfig.String("AppName"); appname != "" {
- AppName = appname
- }
-
- if autorender, err := AppConfig.Bool("AutoRender"); err == nil {
- AutoRender = autorender
- }
-
- if autorecover, err := AppConfig.Bool("RecoverPanic"); err == nil {
- RecoverPanic = autorecover
- }
-
- if views := AppConfig.String("ViewsPath"); views != "" {
- ViewsPath = views
- }
-
- if sessionon, err := AppConfig.Bool("SessionOn"); err == nil {
- SessionOn = sessionon
- }
-
- if sessProvider := AppConfig.String("SessionProvider"); sessProvider != "" {
- SessionProvider = sessProvider
- }
-
- if sessName := AppConfig.String("SessionName"); sessName != "" {
- SessionName = sessName
- }
-
- if sesssavepath := AppConfig.String("SessionSavePath"); sesssavepath != "" {
- SessionSavePath = sesssavepath
- }
-
- if sessMaxLifeTime, err := AppConfig.Int64("SessionGCMaxLifetime"); err == nil && sessMaxLifeTime != 0 {
- SessionGCMaxLifetime = sessMaxLifeTime
- }
-
- if sesscookielifetime, err := AppConfig.Int("SessionCookieLifeTime"); err == nil && sesscookielifetime != 0 {
- SessionCookieLifeTime = sesscookielifetime
- }
-
- if usefcgi, err := AppConfig.Bool("UseFcgi"); err == nil {
- UseFcgi = usefcgi
- }
-
- if enablegzip, err := AppConfig.Bool("EnableGzip"); err == nil {
- EnableGzip = enablegzip
- }
-
- if directoryindex, err := AppConfig.Bool("DirectoryIndex"); err == nil {
- DirectoryIndex = directoryindex
- }
-
- if timeout, err := AppConfig.Int64("HttpServerTimeOut"); err == nil {
- HttpServerTimeOut = timeout
- }
-
- if errorsshow, err := AppConfig.Bool("ErrorsShow"); err == nil {
- ErrorsShow = errorsshow
- }
-
- if copyrequestbody, err := AppConfig.Bool("CopyRequestBody"); err == nil {
- CopyRequestBody = copyrequestbody
- }
-
- if xsrfkey := AppConfig.String("XSRFKEY"); xsrfkey != "" {
- XSRFKEY = xsrfkey
- }
-
- if enablexsrf, err := AppConfig.Bool("EnableXSRF"); err == nil {
- EnableXSRF = enablexsrf
- }
-
- if expire, err := AppConfig.Int("XSRFExpire"); err == nil {
- XSRFExpire = expire
- }
-
- if tplleft := AppConfig.String("TemplateLeft"); tplleft != "" {
- TemplateLeft = tplleft
- }
-
- if tplright := AppConfig.String("TemplateRight"); tplright != "" {
- TemplateRight = tplright
- }
-
- if httptls, err := AppConfig.Bool("EnableHttpTLS"); err == nil {
- EnableHttpTLS = httptls
- }
-
- if httpsport, err := AppConfig.Int("HttpsPort"); err == nil {
- HttpsPort = httpsport
- }
-
- if certfile := AppConfig.String("HttpCertFile"); certfile != "" {
- HttpCertFile = certfile
- }
-
- if keyfile := AppConfig.String("HttpKeyFile"); keyfile != "" {
- HttpKeyFile = keyfile
- }
-
- if serverName := AppConfig.String("BeegoServerName"); serverName != "" {
- BeegoServerName = serverName
- }
-
- if flashname := AppConfig.String("FlashName"); flashname != "" {
- FlashName = flashname
- }
-
- if flashseperator := AppConfig.String("FlashSeperator"); flashseperator != "" {
- FlashSeperator = flashseperator
- }
-
- if sd := AppConfig.String("StaticDir"); sd != "" {
- for k := range StaticDir {
- delete(StaticDir, k)
- }
- sds := strings.Fields(sd)
- for _, v := range sds {
- if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
- StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
- } else {
- StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
- }
- }
- }
-
- if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" {
- extensions := strings.Split(sgz, ",")
- if len(extensions) > 0 {
- StaticExtensionsToGzip = []string{}
- for _, ext := range extensions {
- if len(ext) == 0 {
- continue
- }
- extWithDot := ext
- if extWithDot[:1] != "." {
- extWithDot = "." + extWithDot
- }
- StaticExtensionsToGzip = append(StaticExtensionsToGzip, extWithDot)
- }
- }
- }
-
- if enableadmin, err := AppConfig.Bool("EnableAdmin"); err == nil {
- EnableAdmin = enableadmin
- }
-
- if adminhttpaddr := AppConfig.String("AdminHttpAddr"); adminhttpaddr != "" {
- AdminHttpAddr = adminhttpaddr
- }
-
- if adminhttpport, err := AppConfig.Int("AdminHttpPort"); err == nil {
- AdminHttpPort = adminhttpport
- }
-
- if enabledocs, err := AppConfig.Bool("EnableDocs"); err == nil {
- EnableDocs = enabledocs
- }
-
- if casesensitive, err := AppConfig.Bool("RouterCaseSensitive"); err == nil {
- RouterCaseSensitive = casesensitive
- }
- if graceful, err := AppConfig.Bool("Graceful"); err == nil {
- Graceful = graceful
- }
- return nil
-}
diff --git a/config/config.go b/config/config.go
index 8d9261b8..da5d358b 100644
--- a/config/config.go
+++ b/config/config.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package config is used to parse config
// Usage:
// import(
// "github.com/astaxie/beego/config"
@@ -28,12 +29,12 @@
// cnf.Int64(key string) (int64, error)
// cnf.Bool(key string) (bool, error)
// cnf.Float(key string) (float64, error)
-// cnf.DefaultString(key string, defaultval string) string
-// cnf.DefaultStrings(key string, defaultval []string) []string
-// cnf.DefaultInt(key string, defaultval int) int
-// cnf.DefaultInt64(key string, defaultval int64) int64
-// cnf.DefaultBool(key string, defaultval bool) bool
-// cnf.DefaultFloat(key string, defaultval float64) float64
+// cnf.DefaultString(key string, defaultVal string) string
+// cnf.DefaultStrings(key string, defaultVal []string) []string
+// cnf.DefaultInt(key string, defaultVal int) int
+// cnf.DefaultInt64(key string, defaultVal int64) int64
+// cnf.DefaultBool(key string, defaultVal bool) bool
+// cnf.DefaultFloat(key string, defaultVal float64) float64
// cnf.DIY(key string) (interface{}, error)
// cnf.GetSection(section string) (map[string]string, error)
// cnf.SaveConfigFile(filename string) error
@@ -45,30 +46,30 @@ import (
"fmt"
)
-// ConfigContainer defines how to get and set value from configuration raw data.
-type ConfigContainer interface {
- Set(key, val string) error // support section::key type in given key when using ini type.
- String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
+// Configer defines how to get and set value from configuration raw data.
+type Configer interface {
+ Set(key, val string) error //support section::key type in given key when using ini type.
+ String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Strings(key string) []string //get string slice
Int(key string) (int, error)
Int64(key string) (int64, error)
Bool(key string) (bool, error)
Float(key string) (float64, error)
- DefaultString(key string, defaultval string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
- DefaultStrings(key string, defaultval []string) []string //get string slice
- DefaultInt(key string, defaultval int) int
- DefaultInt64(key string, defaultval int64) int64
- DefaultBool(key string, defaultval bool) bool
- DefaultFloat(key string, defaultval float64) float64
+ DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
+ DefaultStrings(key string, defaultVal []string) []string //get string slice
+ DefaultInt(key string, defaultVal int) int
+ DefaultInt64(key string, defaultVal int64) int64
+ DefaultBool(key string, defaultVal bool) bool
+ DefaultFloat(key string, defaultVal float64) float64
DIY(key string) (interface{}, error)
GetSection(section string) (map[string]string, error)
SaveConfigFile(filename string) error
}
-// Config is the adapter interface for parsing config file to get raw data to ConfigContainer.
+// Config is the adapter interface for parsing config file to get raw data to Configer.
type Config interface {
- Parse(key string) (ConfigContainer, error)
- ParseData(data []byte) (ConfigContainer, error)
+ Parse(key string) (Configer, error)
+ ParseData(data []byte) (Configer, error)
}
var adapters = make(map[string]Config)
@@ -86,19 +87,19 @@ func Register(name string, adapter Config) {
adapters[name] = adapter
}
-// adapterName is ini/json/xml/yaml.
+// NewConfig adapterName is ini/json/xml/yaml.
// filename is the config file path.
-func NewConfig(adapterName, fileaname string) (ConfigContainer, error) {
+func NewConfig(adapterName, filename string) (Configer, error) {
adapter, ok := adapters[adapterName]
if !ok {
return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName)
}
- return adapter.Parse(fileaname)
+ return adapter.Parse(filename)
}
-// adapterName is ini/json/xml/yaml.
+// NewConfigData adapterName is ini/json/xml/yaml.
// data is the config data.
-func NewConfigData(adapterName string, data []byte) (ConfigContainer, error) {
+func NewConfigData(adapterName string, data []byte) (Configer, error) {
adapter, ok := adapters[adapterName]
if !ok {
return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName)
diff --git a/config/fake.go b/config/fake.go
index 54588e5e..50aa5d4a 100644
--- a/config/fake.go
+++ b/config/fake.go
@@ -38,11 +38,11 @@ func (c *fakeConfigContainer) String(key string) string {
}
func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string {
- if v := c.getData(key); v == "" {
+ v := c.getData(key)
+ if v == "" {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) Strings(key string) []string {
@@ -50,11 +50,11 @@ func (c *fakeConfigContainer) Strings(key string) []string {
}
func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string {
- if v := c.Strings(key); len(v) == 0 {
+ v := c.Strings(key)
+ if len(v) == 0 {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) Int(key string) (int, error) {
@@ -62,11 +62,11 @@ func (c *fakeConfigContainer) Int(key string) (int, error) {
}
func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int {
- if v, err := c.Int(key); err != nil {
+ v, err := c.Int(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
@@ -74,11 +74,11 @@ func (c *fakeConfigContainer) Int64(key string) (int64, error) {
}
func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
- if v, err := c.Int64(key); err != nil {
+ v, err := c.Int64(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
@@ -86,11 +86,11 @@ func (c *fakeConfigContainer) Bool(key string) (bool, error) {
}
func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool {
- if v, err := c.Bool(key); err != nil {
+ v, err := c.Bool(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) Float(key string) (float64, error) {
@@ -98,11 +98,11 @@ func (c *fakeConfigContainer) Float(key string) (float64, error) {
}
func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
- if v, err := c.Float(key); err != nil {
+ v, err := c.Float(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
@@ -120,9 +120,10 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error {
return errors.New("not implement in the fakeConfigContainer")
}
-var _ ConfigContainer = new(fakeConfigContainer)
+var _ Configer = new(fakeConfigContainer)
-func NewFakeConfig() ConfigContainer {
+// NewFakeConfig return a fake Congiger
+func NewFakeConfig() Configer {
return &fakeConfigContainer{
data: make(map[string]string),
}
diff --git a/config/ini.go b/config/ini.go
index 31fe9b5f..59e84e1e 100644
--- a/config/ini.go
+++ b/config/ini.go
@@ -31,23 +31,23 @@ import (
)
var (
- DEFAULT_SECTION = "default" // default section means if some ini items not in a section, make them in default section,
- bNumComment = []byte{'#'} // number signal
- bSemComment = []byte{';'} // semicolon signal
- bEmpty = []byte{}
- bEqual = []byte{'='} // equal signal
- bDQuote = []byte{'"'} // quote signal
- sectionStart = []byte{'['} // section start signal
- sectionEnd = []byte{']'} // section end signal
- lineBreak = "\n"
+ defaultSection = "default" // default section means if some ini items not in a section, make them in default section,
+ bNumComment = []byte{'#'} // number signal
+ bSemComment = []byte{';'} // semicolon signal
+ bEmpty = []byte{}
+ bEqual = []byte{'='} // equal signal
+ bDQuote = []byte{'"'} // quote signal
+ sectionStart = []byte{'['} // section start signal
+ sectionEnd = []byte{']'} // section end signal
+ lineBreak = "\n"
)
// IniConfig implements Config to parse ini file.
type IniConfig struct {
}
-// ParseFile creates a new Config and parses the file configuration from the named file.
-func (ini *IniConfig) Parse(name string) (ConfigContainer, error) {
+// Parse creates a new Config and parses the file configuration from the named file.
+func (ini *IniConfig) Parse(name string) (Configer, error) {
return ini.parseFile(name)
}
@@ -77,7 +77,7 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
buf.ReadByte()
}
}
- section := DEFAULT_SECTION
+ section := defaultSection
for {
line, _, err := buf.ReadLine()
if err == io.EOF {
@@ -171,7 +171,8 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
return cfg, nil
}
-func (ini *IniConfig) ParseData(data []byte) (ConfigContainer, error) {
+// ParseData parse ini the data
+func (ini *IniConfig) ParseData(data []byte) (Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
@@ -181,7 +182,7 @@ func (ini *IniConfig) ParseData(data []byte) (ConfigContainer, error) {
return ini.Parse(tmpName)
}
-// A Config represents the ini configuration.
+// IniConfigContainer A Config represents the ini configuration.
// When set and get value, support key as section:name type.
type IniConfigContainer struct {
filename string
@@ -199,11 +200,11 @@ func (c *IniConfigContainer) Bool(key string) (bool, error) {
// DefaultBool returns the boolean value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool {
- if v, err := c.Bool(key); err != nil {
+ v, err := c.Bool(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int returns the integer value for a given key.
@@ -214,11 +215,11 @@ func (c *IniConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int {
- if v, err := c.Int(key); err != nil {
+ v, err := c.Int(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int64 returns the int64 value for a given key.
@@ -229,11 +230,11 @@ func (c *IniConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
- if v, err := c.Int64(key); err != nil {
+ v, err := c.Int64(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Float returns the float value for a given key.
@@ -244,11 +245,11 @@ func (c *IniConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
- if v, err := c.Float(key); err != nil {
+ v, err := c.Float(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// String returns the string value for a given key.
@@ -259,11 +260,11 @@ func (c *IniConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultString(key string, defaultval string) string {
- if v := c.String(key); v == "" {
+ v := c.String(key)
+ if v == "" {
return defaultval
- } else {
- return v
}
+ return v
}
// Strings returns the []string value for a given key.
@@ -274,20 +275,19 @@ func (c *IniConfigContainer) Strings(key string) []string {
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string {
- if v := c.Strings(key); len(v) == 0 {
+ v := c.Strings(key)
+ if len(v) == 0 {
return defaultval
- } else {
- return v
}
+ return v
}
// GetSection returns map for the given section
func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v, nil
- } else {
- return nil, errors.New("not exist setction")
}
+ return nil, errors.New("not exist setction")
}
// SaveConfigFile save the config into file
@@ -301,7 +301,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
buf := bytes.NewBuffer(nil)
// Save default section at first place
- if dt, ok := c.data[DEFAULT_SECTION]; ok {
+ if dt, ok := c.data[defaultSection]; ok {
for key, val := range dt {
if key != " " {
// Write key comments.
@@ -325,7 +325,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
}
// Save named sections
for section, dt := range c.data {
- if section != DEFAULT_SECTION {
+ if section != defaultSection {
// Write section comments.
if v, ok := c.sectionComment[section]; ok {
if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil {
@@ -367,7 +367,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
return nil
}
-// WriteValue writes a new value for key.
+// Set writes a new value for key.
// if write to one section, the key need be "section::key".
// if the section is not existed, it panics.
func (c *IniConfigContainer) Set(key, value string) error {
@@ -379,14 +379,14 @@ func (c *IniConfigContainer) Set(key, value string) error {
var (
section, k string
- sectionKey []string = strings.Split(key, "::")
+ sectionKey = strings.Split(key, "::")
)
if len(sectionKey) >= 2 {
section = sectionKey[0]
k = sectionKey[1]
} else {
- section = DEFAULT_SECTION
+ section = defaultSection
k = sectionKey[0]
}
@@ -415,13 +415,13 @@ func (c *IniConfigContainer) getdata(key string) string {
var (
section, k string
- sectionKey []string = strings.Split(strings.ToLower(key), "::")
+ sectionKey = strings.Split(strings.ToLower(key), "::")
)
if len(sectionKey) >= 2 {
section = sectionKey[0]
k = sectionKey[1]
} else {
- section = DEFAULT_SECTION
+ section = defaultSection
k = sectionKey[0]
}
if v, ok := c.data[section]; ok {
diff --git a/config/json.go b/config/json.go
index e2b53793..6929baad 100644
--- a/config/json.go
+++ b/config/json.go
@@ -23,12 +23,12 @@ import (
"sync"
)
-// JsonConfig is a json config parser and implements Config interface.
-type JsonConfig struct {
+// JSONConfig is a json config parser and implements Config interface.
+type JSONConfig struct {
}
// Parse returns a ConfigContainer with parsed json config map.
-func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
+func (js *JSONConfig) Parse(filename string) (Configer, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
@@ -43,8 +43,8 @@ func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
}
// ParseData returns a ConfigContainer with json string
-func (js *JsonConfig) ParseData(data []byte) (ConfigContainer, error) {
- x := &JsonConfigContainer{
+func (js *JSONConfig) ParseData(data []byte) (Configer, error) {
+ x := &JSONConfigContainer{
data: make(map[string]interface{}),
}
err := json.Unmarshal(data, &x.data)
@@ -59,15 +59,15 @@ func (js *JsonConfig) ParseData(data []byte) (ConfigContainer, error) {
return x, nil
}
-// A Config represents the json configuration.
+// JSONConfigContainer A Config represents the json configuration.
// Only when get value, support key as section:name type.
-type JsonConfigContainer struct {
+type JSONConfigContainer struct {
data map[string]interface{}
sync.RWMutex
}
// Bool returns the boolean value for a given key.
-func (c *JsonConfigContainer) Bool(key string) (bool, error) {
+func (c *JSONConfigContainer) Bool(key string) (bool, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(bool); ok {
@@ -80,7 +80,7 @@ func (c *JsonConfigContainer) Bool(key string) (bool, error) {
// DefaultBool return the bool value if has no error
// otherwise return the defaultval
-func (c *JsonConfigContainer) DefaultBool(key string, defaultval bool) bool {
+func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err == nil {
return v
}
@@ -88,7 +88,7 @@ func (c *JsonConfigContainer) DefaultBool(key string, defaultval bool) bool {
}
// Int returns the integer value for a given key.
-func (c *JsonConfigContainer) Int(key string) (int, error) {
+func (c *JSONConfigContainer) Int(key string) (int, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
@@ -101,7 +101,7 @@ func (c *JsonConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
-func (c *JsonConfigContainer) DefaultInt(key string, defaultval int) int {
+func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err == nil {
return v
}
@@ -109,7 +109,7 @@ func (c *JsonConfigContainer) DefaultInt(key string, defaultval int) int {
}
// Int64 returns the int64 value for a given key.
-func (c *JsonConfigContainer) Int64(key string) (int64, error) {
+func (c *JSONConfigContainer) Int64(key string) (int64, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
@@ -122,7 +122,7 @@ func (c *JsonConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
-func (c *JsonConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
+func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err == nil {
return v
}
@@ -130,7 +130,7 @@ func (c *JsonConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
}
// Float returns the float value for a given key.
-func (c *JsonConfigContainer) Float(key string) (float64, error) {
+func (c *JSONConfigContainer) Float(key string) (float64, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
@@ -143,7 +143,7 @@ func (c *JsonConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
-func (c *JsonConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
+func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err == nil {
return v
}
@@ -151,7 +151,7 @@ func (c *JsonConfigContainer) DefaultFloat(key string, defaultval float64) float
}
// String returns the string value for a given key.
-func (c *JsonConfigContainer) String(key string) string {
+func (c *JSONConfigContainer) String(key string) string {
val := c.getData(key)
if val != nil {
if v, ok := val.(string); ok {
@@ -163,7 +163,7 @@ func (c *JsonConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
-func (c *JsonConfigContainer) DefaultString(key string, defaultval string) string {
+func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string {
// TODO FIXME should not use "" to replace non existance
if v := c.String(key); v != "" {
return v
@@ -172,7 +172,7 @@ func (c *JsonConfigContainer) DefaultString(key string, defaultval string) strin
}
// Strings returns the []string value for a given key.
-func (c *JsonConfigContainer) Strings(key string) []string {
+func (c *JSONConfigContainer) Strings(key string) []string {
stringVal := c.String(key)
if stringVal == "" {
return []string{}
@@ -182,7 +182,7 @@ func (c *JsonConfigContainer) Strings(key string) []string {
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
-func (c *JsonConfigContainer) DefaultStrings(key string, defaultval []string) []string {
+func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) > 0 {
return v
}
@@ -190,7 +190,7 @@ func (c *JsonConfigContainer) DefaultStrings(key string, defaultval []string) []
}
// GetSection returns map for the given section
-func (c *JsonConfigContainer) GetSection(section string) (map[string]string, error) {
+func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v.(map[string]string), nil
}
@@ -198,7 +198,7 @@ func (c *JsonConfigContainer) GetSection(section string) (map[string]string, err
}
// SaveConfigFile save the config into file
-func (c *JsonConfigContainer) SaveConfigFile(filename string) (err error) {
+func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
@@ -214,7 +214,7 @@ func (c *JsonConfigContainer) SaveConfigFile(filename string) (err error) {
}
// Set writes a new value for key.
-func (c *JsonConfigContainer) Set(key, val string) error {
+func (c *JSONConfigContainer) Set(key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
@@ -222,7 +222,7 @@ func (c *JsonConfigContainer) Set(key, val string) error {
}
// DIY returns the raw value by a given key.
-func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
+func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) {
val := c.getData(key)
if val != nil {
return val, nil
@@ -231,7 +231,7 @@ func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
}
// section.key or key
-func (c *JsonConfigContainer) getData(key string) interface{} {
+func (c *JSONConfigContainer) getData(key string) interface{} {
if len(key) == 0 {
return nil
}
@@ -261,5 +261,5 @@ func (c *JsonConfigContainer) getData(key string) interface{} {
}
func init() {
- Register("json", &JsonConfig{})
+ Register("json", &JSONConfig{})
}
diff --git a/config/xml/xml.go b/config/xml/xml.go
index a1d9fcdb..4d48f7d2 100644
--- a/config/xml/xml.go
+++ b/config/xml/xml.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package xml for config provider
+// Package xml for config provider
//
// depend on github.com/beego/x2j
//
@@ -45,20 +45,20 @@ import (
"github.com/beego/x2j"
)
-// XmlConfig is a xml config parser and implements Config interface.
+// Config is a xml config parser and implements Config interface.
// xml configurations should be included in tag.
// only support key/value pair as value as each item.
-type XMLConfig struct{}
+type Config struct{}
// Parse returns a ConfigContainer with parsed xml config map.
-func (xc *XMLConfig) Parse(filename string) (config.ConfigContainer, error) {
+func (xc *Config) Parse(filename string) (config.Configer, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()
- x := &XMLConfigContainer{data: make(map[string]interface{})}
+ x := &ConfigContainer{data: make(map[string]interface{})}
content, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
@@ -73,84 +73,86 @@ func (xc *XMLConfig) Parse(filename string) (config.ConfigContainer, error) {
return x, nil
}
-func (x *XMLConfig) ParseData(data []byte) (config.ConfigContainer, error) {
+// ParseData xml data
+func (xc *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
return nil, err
}
- return x.Parse(tmpName)
+ return xc.Parse(tmpName)
}
-// A Config represents the xml configuration.
-type XMLConfigContainer struct {
+// ConfigContainer A Config represents the xml configuration.
+type ConfigContainer struct {
data map[string]interface{}
sync.Mutex
}
// Bool returns the boolean value for a given key.
-func (c *XMLConfigContainer) Bool(key string) (bool, error) {
+func (c *ConfigContainer) Bool(key string) (bool, error) {
return strconv.ParseBool(c.data[key].(string))
}
// DefaultBool return the bool value if has no error
// otherwise return the defaultval
-func (c *XMLConfigContainer) DefaultBool(key string, defaultval bool) bool {
- if v, err := c.Bool(key); err != nil {
+func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
+ v, err := c.Bool(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int returns the integer value for a given key.
-func (c *XMLConfigContainer) Int(key string) (int, error) {
+func (c *ConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.data[key].(string))
}
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
-func (c *XMLConfigContainer) DefaultInt(key string, defaultval int) int {
- if v, err := c.Int(key); err != nil {
+func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
+ v, err := c.Int(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int64 returns the int64 value for a given key.
-func (c *XMLConfigContainer) Int64(key string) (int64, error) {
+func (c *ConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.data[key].(string), 10, 64)
}
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
-func (c *XMLConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
- if v, err := c.Int64(key); err != nil {
+func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
+ v, err := c.Int64(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
+
}
// Float returns the float value for a given key.
-func (c *XMLConfigContainer) Float(key string) (float64, error) {
+func (c *ConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.data[key].(string), 64)
}
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
-func (c *XMLConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
- if v, err := c.Float(key); err != nil {
+func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
+ v, err := c.Float(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// String returns the string value for a given key.
-func (c *XMLConfigContainer) String(key string) string {
+func (c *ConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok {
return v
}
@@ -159,40 +161,39 @@ func (c *XMLConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
-func (c *XMLConfigContainer) DefaultString(key string, defaultval string) string {
- if v := c.String(key); v == "" {
+func (c *ConfigContainer) DefaultString(key string, defaultval string) string {
+ v := c.String(key)
+ if v == "" {
return defaultval
- } else {
- return v
}
+ return v
}
// Strings returns the []string value for a given key.
-func (c *XMLConfigContainer) Strings(key string) []string {
+func (c *ConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
-func (c *XMLConfigContainer) DefaultStrings(key string, defaultval []string) []string {
- if v := c.Strings(key); len(v) == 0 {
+func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string {
+ v := c.Strings(key)
+ if len(v) == 0 {
return defaultval
- } else {
- return v
}
+ return v
}
// GetSection returns map for the given section
-func (c *XMLConfigContainer) GetSection(section string) (map[string]string, error) {
+func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v.(map[string]string), nil
- } else {
- return nil, errors.New("not exist setction")
}
+ return nil, errors.New("not exist setction")
}
// SaveConfigFile save the config into file
-func (c *XMLConfigContainer) SaveConfigFile(filename string) (err error) {
+func (c *ConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
@@ -207,8 +208,8 @@ func (c *XMLConfigContainer) SaveConfigFile(filename string) (err error) {
return err
}
-// WriteValue writes a new value for key.
-func (c *XMLConfigContainer) Set(key, val string) error {
+// Set writes a new value for key.
+func (c *ConfigContainer) Set(key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
@@ -216,7 +217,7 @@ func (c *XMLConfigContainer) Set(key, val string) error {
}
// DIY returns the raw value by a given key.
-func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
+func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok {
return v, nil
}
@@ -224,5 +225,5 @@ func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
}
func init() {
- config.Register("xml", &XMLConfig{})
+ config.Register("xml", &Config{})
}
diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go
index c5be44a9..f034d3ba 100644
--- a/config/yaml/yaml.go
+++ b/config/yaml/yaml.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package yaml for config provider
+// Package yaml for config provider
//
// depend on github.com/beego/goyaml2
//
@@ -46,22 +46,23 @@ import (
"github.com/beego/goyaml2"
)
-// YAMLConfig is a yaml config parser and implements Config interface.
-type YAMLConfig struct{}
+// Config is a yaml config parser and implements Config interface.
+type Config struct{}
// Parse returns a ConfigContainer with parsed yaml config map.
-func (yaml *YAMLConfig) Parse(filename string) (y config.ConfigContainer, err error) {
+func (yaml *Config) Parse(filename string) (y config.Configer, err error) {
cnf, err := ReadYmlReader(filename)
if err != nil {
return
}
- y = &YAMLConfigContainer{
+ y = &ConfigContainer{
data: cnf,
}
return
}
-func (yaml *YAMLConfig) ParseData(data []byte) (config.ConfigContainer, error) {
+// ParseData parse yaml data
+func (yaml *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
@@ -71,7 +72,7 @@ func (yaml *YAMLConfig) ParseData(data []byte) (config.ConfigContainer, error) {
return yaml.Parse(tmpName)
}
-// Read yaml file to map.
+// ReadYmlReader Read yaml file to map.
// if json like, use json package, unless goyaml2 package.
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
f, err := os.Open(path)
@@ -112,14 +113,14 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
return
}
-// A Config represents the yaml configuration.
-type YAMLConfigContainer struct {
+// ConfigContainer A Config represents the yaml configuration.
+type ConfigContainer struct {
data map[string]interface{}
sync.Mutex
}
// Bool returns the boolean value for a given key.
-func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
+func (c *ConfigContainer) Bool(key string) (bool, error) {
if v, ok := c.data[key].(bool); ok {
return v, nil
}
@@ -128,16 +129,16 @@ func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
// DefaultBool return the bool value if has no error
// otherwise return the defaultval
-func (c *YAMLConfigContainer) DefaultBool(key string, defaultval bool) bool {
- if v, err := c.Bool(key); err != nil {
+func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
+ v, err := c.Bool(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int returns the integer value for a given key.
-func (c *YAMLConfigContainer) Int(key string) (int, error) {
+func (c *ConfigContainer) Int(key string) (int, error) {
if v, ok := c.data[key].(int64); ok {
return int(v), nil
}
@@ -146,16 +147,16 @@ func (c *YAMLConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
-func (c *YAMLConfigContainer) DefaultInt(key string, defaultval int) int {
- if v, err := c.Int(key); err != nil {
+func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
+ v, err := c.Int(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int64 returns the int64 value for a given key.
-func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
+func (c *ConfigContainer) Int64(key string) (int64, error) {
if v, ok := c.data[key].(int64); ok {
return v, nil
}
@@ -164,16 +165,16 @@ func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
-func (c *YAMLConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
- if v, err := c.Int64(key); err != nil {
+func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
+ v, err := c.Int64(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Float returns the float value for a given key.
-func (c *YAMLConfigContainer) Float(key string) (float64, error) {
+func (c *ConfigContainer) Float(key string) (float64, error) {
if v, ok := c.data[key].(float64); ok {
return v, nil
}
@@ -182,16 +183,16 @@ func (c *YAMLConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
-func (c *YAMLConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
- if v, err := c.Float(key); err != nil {
+func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
+ v, err := c.Float(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// String returns the string value for a given key.
-func (c *YAMLConfigContainer) String(key string) string {
+func (c *ConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok {
return v
}
@@ -200,40 +201,40 @@ func (c *YAMLConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
-func (c *YAMLConfigContainer) DefaultString(key string, defaultval string) string {
- if v := c.String(key); v == "" {
+func (c *ConfigContainer) DefaultString(key string, defaultval string) string {
+ v := c.String(key)
+ if v == "" {
return defaultval
- } else {
- return v
}
+ return v
}
// Strings returns the []string value for a given key.
-func (c *YAMLConfigContainer) Strings(key string) []string {
+func (c *ConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
-func (c *YAMLConfigContainer) DefaultStrings(key string, defaultval []string) []string {
- if v := c.Strings(key); len(v) == 0 {
+func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string {
+ v := c.Strings(key)
+ if len(v) == 0 {
return defaultval
- } else {
- return v
}
+ return v
}
// GetSection returns map for the given section
-func (c *YAMLConfigContainer) GetSection(section string) (map[string]string, error) {
- if v, ok := c.data[section]; ok {
+func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
+ v, ok := c.data[section]
+ if ok {
return v.(map[string]string), nil
- } else {
- return nil, errors.New("not exist setction")
}
+ return nil, errors.New("not exist setction")
}
// SaveConfigFile save the config into file
-func (c *YAMLConfigContainer) SaveConfigFile(filename string) (err error) {
+func (c *ConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
@@ -244,8 +245,8 @@ func (c *YAMLConfigContainer) SaveConfigFile(filename string) (err error) {
return err
}
-// WriteValue writes a new value for key.
-func (c *YAMLConfigContainer) Set(key, val string) error {
+// Set writes a new value for key.
+func (c *ConfigContainer) Set(key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
@@ -253,7 +254,7 @@ func (c *YAMLConfigContainer) Set(key, val string) error {
}
// DIY returns the raw value by a given key.
-func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
+func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok {
return v, nil
}
@@ -261,5 +262,5 @@ func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
}
func init() {
- config.Register("yaml", &YAMLConfig{})
+ config.Register("yaml", &Config{})
}
diff --git a/config_test.go b/config_test.go
index 17645f80..cf4a781d 100644
--- a/config_test.go
+++ b/config_test.go
@@ -19,11 +19,11 @@ import (
)
func TestDefaults(t *testing.T) {
- if FlashName != "BEEGO_FLASH" {
+ if BConfig.WebConfig.FlashName != "BEEGO_FLASH" {
t.Errorf("FlashName was not set to default.")
}
- if FlashSeperator != "BEEGOFLASH" {
+ if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" {
t.Errorf("FlashName was not set to default.")
}
}
diff --git a/context/acceptencoder.go b/context/acceptencoder.go
new file mode 100644
index 00000000..07c5cb0b
--- /dev/null
+++ b/context/acceptencoder.go
@@ -0,0 +1,198 @@
+// Copyright 2015 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package context
+
+import (
+ "bytes"
+ "compress/flate"
+ "compress/gzip"
+ "compress/zlib"
+ "io"
+ "net/http"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+)
+
+type resetWriter interface {
+ io.Writer
+ Reset(w io.Writer)
+}
+
+type nopResetWriter struct {
+ io.Writer
+}
+
+func (n nopResetWriter) Reset(w io.Writer) {
+ //do nothing
+}
+
+type acceptEncoder struct {
+ name string
+ levelEncode func(int) resetWriter
+ bestSpeedPool *sync.Pool
+ bestCompressionPool *sync.Pool
+}
+
+func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
+ if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil {
+ return nopResetWriter{wr}
+ }
+ var rwr resetWriter
+ switch level {
+ case flate.BestSpeed:
+ rwr = ac.bestSpeedPool.Get().(resetWriter)
+ case flate.BestCompression:
+ rwr = ac.bestCompressionPool.Get().(resetWriter)
+ default:
+ rwr = ac.levelEncode(level)
+ }
+ rwr.Reset(wr)
+ return rwr
+}
+
+func (ac acceptEncoder) put(wr resetWriter, level int) {
+ if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil {
+ return
+ }
+ wr.Reset(nil)
+ switch level {
+ case flate.BestSpeed:
+ ac.bestSpeedPool.Put(wr)
+ case flate.BestCompression:
+ ac.bestCompressionPool.Put(wr)
+ }
+}
+
+var (
+ noneCompressEncoder = acceptEncoder{"", nil, nil, nil}
+ gzipCompressEncoder = acceptEncoder{"gzip",
+ func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr },
+ &sync.Pool{
+ New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestSpeed); return wr },
+ },
+ &sync.Pool{
+ New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr },
+ },
+ }
+
+ //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed
+ //deflate
+ //The "zlib" format defined in RFC 1950 [31] in combination with
+ //the "deflate" compression mechanism described in RFC 1951 [29].
+ deflateCompressEncoder = acceptEncoder{"deflate",
+ func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr },
+ &sync.Pool{
+ New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestSpeed); return wr },
+ },
+ &sync.Pool{
+ New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr },
+ },
+ }
+)
+
+var (
+ encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore
+ "gzip": gzipCompressEncoder,
+ "deflate": deflateCompressEncoder,
+ "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip
+ "identity": noneCompressEncoder, // identity means none-compress
+ }
+)
+
+// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate)
+func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) {
+ return writeLevel(encoding, writer, file, flate.BestCompression)
+}
+
+// WriteBody reads writes content to writer by the specific encoding(gzip/deflate)
+func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) {
+ return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed)
+}
+
+// writeLevel reads from reader,writes to writer by specific encoding and compress level
+// the compress level is defined by deflate package
+func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) {
+ var outputWriter resetWriter
+ var err error
+ var ce = noneCompressEncoder
+
+ if cf, ok := encoderMap[encoding]; ok {
+ ce = cf
+ }
+ encoding = ce.name
+ outputWriter = ce.encode(writer, level)
+ defer ce.put(outputWriter, level)
+
+ _, err = io.Copy(outputWriter, reader)
+ if err != nil {
+ return false, "", err
+ }
+
+ switch outputWriter.(type) {
+ case io.WriteCloser:
+ outputWriter.(io.WriteCloser).Close()
+ }
+ return encoding != "", encoding, nil
+}
+
+// ParseEncoding will extract the right encoding for response
+// the Accept-Encoding's sec is here:
+// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3
+func ParseEncoding(r *http.Request) string {
+ if r == nil {
+ return ""
+ }
+ return parseEncoding(r)
+}
+
+type q struct {
+ name string
+ value float64
+}
+
+func parseEncoding(r *http.Request) string {
+ acceptEncoding := r.Header.Get("Accept-Encoding")
+ if acceptEncoding == "" {
+ return ""
+ }
+ var lastQ q
+ for _, v := range strings.Split(acceptEncoding, ",") {
+ v = strings.TrimSpace(v)
+ if v == "" {
+ continue
+ }
+ vs := strings.Split(v, ";")
+ if len(vs) == 1 {
+ lastQ = q{vs[0], 1}
+ break
+ }
+ if len(vs) == 2 {
+ f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64)
+ if f == 0 {
+ continue
+ }
+ if f > lastQ.value {
+ lastQ = q{vs[0], f}
+ }
+ }
+ }
+ if cf, ok := encoderMap[lastQ.name]; ok {
+ return cf.name
+ } else {
+ return ""
+ }
+}
diff --git a/context/acceptencoder_test.go b/context/acceptencoder_test.go
new file mode 100644
index 00000000..147313c5
--- /dev/null
+++ b/context/acceptencoder_test.go
@@ -0,0 +1,45 @@
+// Copyright 2015 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package context
+
+import (
+ "net/http"
+ "testing"
+)
+
+func Test_ExtractEncoding(t *testing.T) {
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip,deflate"}}}) != "gzip" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"deflate,gzip"}}}) != "deflate" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=.5,deflate"}}}) != "deflate" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" {
+ t.Fail()
+ }
+
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=0,deflate"}}}) != "deflate" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"*"}}}) != "gzip" {
+ t.Fail()
+ }
+}
diff --git a/context/context.go b/context/context.go
index f6aa85d6..db790ff2 100644
--- a/context/context.go
+++ b/context/context.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package context provide the context utils
// Usage:
//
// import "github.com/astaxie/beego/context"
@@ -22,10 +23,13 @@
package context
import (
+ "bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
+ "errors"
"fmt"
+ "net"
"net/http"
"strconv"
"strings"
@@ -34,14 +38,30 @@ import (
"github.com/astaxie/beego/utils"
)
-// Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
+// NewContext return the Context with Input and Output
+func NewContext() *Context {
+ return &Context{
+ Input: NewInput(),
+ Output: NewOutput(),
+ }
+}
+
+// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
// BeegoInput and BeegoOutput provides some api to operate request and response more easily.
type Context struct {
Input *BeegoInput
Output *BeegoOutput
Request *http.Request
- ResponseWriter http.ResponseWriter
- _xsrf_token string
+ ResponseWriter *Response
+ _xsrfToken string
+}
+
+// Reset init Context, BeegoInput and BeegoOutput
+func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
+ ctx.Request = r
+ ctx.ResponseWriter = &Response{rw, false, 0}
+ ctx.Input.Reset(ctx)
+ ctx.Output.Reset(ctx)
}
// Redirect does redirection to localurl with http header status code.
@@ -54,29 +74,28 @@ func (ctx *Context) Redirect(status int, localurl string) {
// Abort stops this request.
// if beego.ErrorMaps exists, panic body.
func (ctx *Context) Abort(status int, body string) {
- ctx.ResponseWriter.WriteHeader(status)
panic(body)
}
-// Write string to response body.
+// WriteString Write string to response body.
// it sends response body.
func (ctx *Context) WriteString(content string) {
ctx.ResponseWriter.Write([]byte(content))
}
-// Get cookie from request by a given key.
+// GetCookie Get cookie from request by a given key.
// It's alias of BeegoInput.Cookie.
func (ctx *Context) GetCookie(key string) string {
return ctx.Input.Cookie(key)
}
-// Set cookie for response.
+// SetCookie Set cookie for response.
// It's alias of BeegoOutput.Cookie.
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
ctx.Output.Cookie(name, value, others...)
}
-// Get secure cookie from request by a given key.
+// GetSecureCookie Get secure cookie from request by a given key.
func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
val := ctx.Input.Cookie(key)
if val == "" {
@@ -103,7 +122,7 @@ func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
return string(res), true
}
-// Set Secure cookie for response.
+// SetSecureCookie Set Secure cookie for response.
func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) {
vs := base64.URLEncoding.EncodeToString([]byte(value))
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
@@ -114,23 +133,23 @@ func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interf
ctx.Output.Cookie(name, cookie, others...)
}
-// XsrfToken creates a xsrf token string and returns.
-func (ctx *Context) XsrfToken(key string, expire int64) string {
- if ctx._xsrf_token == "" {
+// XSRFToken creates a xsrf token string and returns.
+func (ctx *Context) XSRFToken(key string, expire int64) string {
+ if ctx._xsrfToken == "" {
token, ok := ctx.GetSecureCookie(key, "_xsrf")
if !ok {
token = string(utils.RandomCreateBytes(32))
ctx.SetSecureCookie(key, "_xsrf", token, expire)
}
- ctx._xsrf_token = token
+ ctx._xsrfToken = token
}
- return ctx._xsrf_token
+ return ctx._xsrfToken
}
-// CheckXsrfCookie checks xsrf token in this request is valid or not.
+// CheckXSRFCookie checks xsrf token in this request is valid or not.
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
// or in form field value named as "_xsrf".
-func (ctx *Context) CheckXsrfCookie() bool {
+func (ctx *Context) CheckXSRFCookie() bool {
token := ctx.Input.Query("_xsrf")
if token == "" {
token = ctx.Request.Header.Get("X-Xsrftoken")
@@ -142,9 +161,57 @@ func (ctx *Context) CheckXsrfCookie() bool {
ctx.Abort(403, "'_xsrf' argument missing from POST")
return false
}
- if ctx._xsrf_token != token {
+ if ctx._xsrfToken != token {
ctx.Abort(403, "XSRF cookie does not match POST argument")
return false
}
return true
}
+
+//Response is a wrapper for the http.ResponseWriter
+//started set to true if response was written to then don't execute other handler
+type Response struct {
+ http.ResponseWriter
+ Started bool
+ Status int
+}
+
+// Write writes the data to the connection as part of an HTTP reply,
+// and sets `started` to true.
+// started means the response has sent out.
+func (w *Response) Write(p []byte) (int, error) {
+ w.Started = true
+ return w.ResponseWriter.Write(p)
+}
+
+// WriteHeader sends an HTTP response header with status code,
+// and sets `started` to true.
+func (w *Response) WriteHeader(code int) {
+ w.Status = code
+ w.Started = true
+ w.ResponseWriter.WriteHeader(code)
+}
+
+// Hijack hijacker for http
+func (w *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ hj, ok := w.ResponseWriter.(http.Hijacker)
+ if !ok {
+ return nil, nil, errors.New("webserver doesn't support hijacking")
+ }
+ return hj.Hijack()
+}
+
+// Flush http.Flusher
+func (w *Response) Flush() {
+ if f, ok := w.ResponseWriter.(http.Flusher); ok {
+ f.Flush()
+ }
+}
+
+// CloseNotify http.CloseNotifier
+func (w *Response) CloseNotify() <-chan bool {
+ if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
+ return cn.CloseNotify()
+ }
+ return nil
+}
diff --git a/context/input.go b/context/input.go
index 1985df21..c37204bd 100644
--- a/context/input.go
+++ b/context/input.go
@@ -17,8 +17,8 @@ package context
import (
"bytes"
"errors"
+ "io"
"io/ioutil"
- "net/http"
"net/url"
"reflect"
"regexp"
@@ -31,45 +31,55 @@ import (
// Regexes for checking the accept headers
// TODO make sure these are correct
var (
- acceptsHtmlRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`)
- acceptsXmlRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`)
- acceptsJsonRegex = regexp.MustCompile(`(application/json)(?:,|$)`)
+ acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`)
+ acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`)
+ acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`)
+ maxParam = 50
)
// BeegoInput operates the http request header, data, cookie and body.
// it also contains router params and current session.
type BeegoInput struct {
- CruSession session.SessionStore
- Params map[string]string
- Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
- Request *http.Request
- RequestBody []byte
- RunController reflect.Type
- RunMethod string
+ Context *Context
+ CruSession session.Store
+ pnames []string
+ pvalues []string
+ data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
+ RequestBody []byte
}
-// NewInput return BeegoInput generated by http.Request.
-func NewInput(req *http.Request) *BeegoInput {
+// NewInput return BeegoInput generated by Context.
+func NewInput() *BeegoInput {
return &BeegoInput{
- Params: make(map[string]string),
- Data: make(map[interface{}]interface{}),
- Request: req,
+ pnames: make([]string, 0, maxParam),
+ pvalues: make([]string, 0, maxParam),
+ data: make(map[interface{}]interface{}),
}
}
+// Reset init the BeegoInput
+func (input *BeegoInput) Reset(ctx *Context) {
+ input.Context = ctx
+ input.CruSession = nil
+ input.pnames = input.pnames[:0]
+ input.pvalues = input.pvalues[:0]
+ input.data = nil
+ input.RequestBody = []byte{}
+}
+
// Protocol returns request protocol name, such as HTTP/1.1 .
func (input *BeegoInput) Protocol() string {
- return input.Request.Proto
+ return input.Context.Request.Proto
}
-// Uri returns full request url with query string, fragment.
-func (input *BeegoInput) Uri() string {
- return input.Request.RequestURI
+// URI returns full request url with query string, fragment.
+func (input *BeegoInput) URI() string {
+ return input.Context.Request.RequestURI
}
-// Url returns request url path (without query string, fragment).
-func (input *BeegoInput) Url() string {
- return input.Request.URL.Path
+// URL returns request url path (without query string, fragment).
+func (input *BeegoInput) URL() string {
+ return input.Context.Request.URL.Path
}
// Site returns base site url as scheme://domain type.
@@ -79,10 +89,10 @@ func (input *BeegoInput) Site() string {
// Scheme returns request scheme as "http" or "https".
func (input *BeegoInput) Scheme() string {
- if input.Request.URL.Scheme != "" {
- return input.Request.URL.Scheme
+ if input.Context.Request.URL.Scheme != "" {
+ return input.Context.Request.URL.Scheme
}
- if input.Request.TLS == nil {
+ if input.Context.Request.TLS == nil {
return "http"
}
return "https"
@@ -97,19 +107,19 @@ func (input *BeegoInput) Domain() string {
// Host returns host name.
// if no host info in request, return localhost.
func (input *BeegoInput) Host() string {
- if input.Request.Host != "" {
- hostParts := strings.Split(input.Request.Host, ":")
+ if input.Context.Request.Host != "" {
+ hostParts := strings.Split(input.Context.Request.Host, ":")
if len(hostParts) > 0 {
return hostParts[0]
}
- return input.Request.Host
+ return input.Context.Request.Host
}
return "localhost"
}
// Method returns http request method.
func (input *BeegoInput) Method() string {
- return input.Request.Method
+ return input.Context.Request.Method
}
// Is returns boolean of this request is on given method, such as Is("POST").
@@ -117,37 +127,37 @@ func (input *BeegoInput) Is(method string) bool {
return input.Method() == method
}
-// Is this a GET method request?
+// IsGet Is this a GET method request?
func (input *BeegoInput) IsGet() bool {
return input.Is("GET")
}
-// Is this a POST method request?
+// IsPost Is this a POST method request?
func (input *BeegoInput) IsPost() bool {
return input.Is("POST")
}
-// Is this a Head method request?
+// IsHead Is this a Head method request?
func (input *BeegoInput) IsHead() bool {
return input.Is("HEAD")
}
-// Is this a OPTIONS method request?
+// IsOptions Is this a OPTIONS method request?
func (input *BeegoInput) IsOptions() bool {
return input.Is("OPTIONS")
}
-// Is this a PUT method request?
+// IsPut Is this a PUT method request?
func (input *BeegoInput) IsPut() bool {
return input.Is("PUT")
}
-// Is this a DELETE method request?
+// IsDelete Is this a DELETE method request?
func (input *BeegoInput) IsDelete() bool {
return input.Is("DELETE")
}
-// Is this a PATCH method request?
+// IsPatch Is this a PATCH method request?
func (input *BeegoInput) IsPatch() bool {
return input.Is("PATCH")
}
@@ -172,19 +182,19 @@ func (input *BeegoInput) IsUpload() bool {
return strings.Contains(input.Header("Content-Type"), "multipart/form-data")
}
-// Checks if request accepts html response
-func (input *BeegoInput) AcceptsHtml() bool {
- return acceptsHtmlRegex.MatchString(input.Header("Accept"))
+// AcceptsHTML Checks if request accepts html response
+func (input *BeegoInput) AcceptsHTML() bool {
+ return acceptsHTMLRegex.MatchString(input.Header("Accept"))
}
-// Checks if request accepts xml response
-func (input *BeegoInput) AcceptsXml() bool {
- return acceptsXmlRegex.MatchString(input.Header("Accept"))
+// AcceptsXML Checks if request accepts xml response
+func (input *BeegoInput) AcceptsXML() bool {
+ return acceptsXMLRegex.MatchString(input.Header("Accept"))
}
-// Checks if request accepts json response
-func (input *BeegoInput) AcceptsJson() bool {
- return acceptsJsonRegex.MatchString(input.Header("Accept"))
+// AcceptsJSON Checks if request accepts json response
+func (input *BeegoInput) AcceptsJSON() bool {
+ return acceptsJSONRegex.MatchString(input.Header("Accept"))
}
// IP returns request client ip.
@@ -196,7 +206,7 @@ func (input *BeegoInput) IP() string {
rip := strings.Split(ips[0], ":")
return rip[0]
}
- ip := strings.Split(input.Request.RemoteAddr, ":")
+ ip := strings.Split(input.Context.Request.RemoteAddr, ":")
if len(ip) > 0 {
if ip[0] != "[" {
return ip[0]
@@ -236,7 +246,7 @@ func (input *BeegoInput) SubDomains() string {
// Port returns request client port.
// when error or empty, return 80.
func (input *BeegoInput) Port() int {
- parts := strings.Split(input.Request.Host, ":")
+ parts := strings.Split(input.Context.Request.Host, ":")
if len(parts) == 2 {
port, _ := strconv.Atoi(parts[1])
return port
@@ -249,35 +259,59 @@ func (input *BeegoInput) UserAgent() string {
return input.Header("User-Agent")
}
+// ParamsLen return the length of the params
+func (input *BeegoInput) ParamsLen() int {
+ return len(input.pnames)
+}
+
// Param returns router param by a given key.
func (input *BeegoInput) Param(key string) string {
- if v, ok := input.Params[key]; ok {
- return v
+ for i, v := range input.pnames {
+ if v == key && i <= len(input.pvalues) {
+ return input.pvalues[i]
+ }
}
return ""
}
+// Params returns the map[key]value.
+func (input *BeegoInput) Params() map[string]string {
+ m := make(map[string]string)
+ for i, v := range input.pnames {
+ if i <= len(input.pvalues) {
+ m[v] = input.pvalues[i]
+ }
+ }
+ return m
+}
+
+// SetParam will set the param with key and value
+func (input *BeegoInput) SetParam(key, val string) {
+ input.pvalues = append(input.pvalues, val)
+ input.pnames = append(input.pnames, key)
+}
+
// Query returns input data item string by a given string.
func (input *BeegoInput) Query(key string) string {
if val := input.Param(key); val != "" {
return val
}
- if input.Request.Form == nil {
- input.Request.ParseForm()
+ if input.Context.Request.Form == nil {
+ input.Context.Request.ParseForm()
}
- return input.Request.Form.Get(key)
+ return input.Context.Request.Form.Get(key)
}
// Header returns request header item string by a given string.
// if non-existed, return empty string.
func (input *BeegoInput) Header(key string) string {
- return input.Request.Header.Get(key)
+ return input.Context.Request.Header.Get(key)
}
// Cookie returns request cookie item string by a given key.
// if non-existed, return empty string.
func (input *BeegoInput) Cookie(key string) string {
- ck, err := input.Request.Cookie(key)
+ ck, err := input.Context.Request.Cookie(key)
if err != nil {
return ""
}
@@ -291,18 +325,27 @@ func (input *BeegoInput) Session(key interface{}) interface{} {
}
// CopyBody returns the raw request body data as bytes.
-func (input *BeegoInput) CopyBody() []byte {
- requestbody, _ := ioutil.ReadAll(input.Request.Body)
- input.Request.Body.Close()
+func (input *BeegoInput) CopyBody(MaxMemory int64) []byte {
+ safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory}
+ requestbody, _ := ioutil.ReadAll(safe)
+ input.Context.Request.Body.Close()
bf := bytes.NewBuffer(requestbody)
- input.Request.Body = ioutil.NopCloser(bf)
+ input.Context.Request.Body = ioutil.NopCloser(bf)
input.RequestBody = requestbody
return requestbody
}
+// Data return the implicit data in the input
+func (input *BeegoInput) Data() map[interface{}]interface{} {
+ if input.data == nil {
+ input.data = make(map[interface{}]interface{})
+ }
+ return input.data
+}
+
// GetData returns the stored data in this context.
func (input *BeegoInput) GetData(key interface{}) interface{} {
- if v, ok := input.Data[key]; ok {
+ if v, ok := input.data[key]; ok {
return v
}
return nil
@@ -311,17 +354,20 @@ func (input *BeegoInput) GetData(key interface{}) interface{} {
// SetData stores data with given key in this context.
// This data are only available in this context.
func (input *BeegoInput) SetData(key, val interface{}) {
- input.Data[key] = val
+ if input.data == nil {
+ input.data = make(map[interface{}]interface{})
+ }
+ input.data[key] = val
}
-// parseForm or parseMultiForm based on Content-type
+// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type
func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error {
// Parse the body depending on the content type.
if strings.Contains(input.Header("Content-Type"), "multipart/form-data") {
- if err := input.Request.ParseMultipartForm(maxMemory); err != nil {
+ if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil {
return errors.New("Error parsing request body:" + err.Error())
}
- } else if err := input.Request.ParseForm(); err != nil {
+ } else if err := input.Context.Request.ParseForm(); err != nil {
return errors.New("Error parsing request body:" + err.Error())
}
return nil
@@ -353,7 +399,7 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error {
}
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
- rv := reflect.Zero(reflect.TypeOf(0))
+ rv := reflect.Zero(typ)
switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
val := input.Query(key)
@@ -386,19 +432,19 @@ func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
}
rv = input.bindBool(val, typ)
case reflect.Slice:
- rv = input.bindSlice(&input.Request.Form, key, typ)
+ rv = input.bindSlice(&input.Context.Request.Form, key, typ)
case reflect.Struct:
- rv = input.bindStruct(&input.Request.Form, key, typ)
+ rv = input.bindStruct(&input.Context.Request.Form, key, typ)
case reflect.Ptr:
rv = input.bindPoint(key, typ)
case reflect.Map:
- rv = input.bindMap(&input.Request.Form, key, typ)
+ rv = input.bindMap(&input.Context.Request.Form, key, typ)
}
return rv
}
func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value {
- rv := reflect.Zero(reflect.TypeOf(0))
+ rv := reflect.Zero(typ)
switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
rv = input.bindInt(val, typ)
diff --git a/context/input_test.go b/context/input_test.go
index 4566f6d6..618e1254 100644
--- a/context/input_test.go
+++ b/context/input_test.go
@@ -17,12 +17,15 @@ package context
import (
"fmt"
"net/http"
+ "net/http/httptest"
"testing"
)
func TestParse(t *testing.T) {
r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
- beegoInput := NewInput(r)
+ beegoInput := NewInput()
+ beegoInput.Context = NewContext()
+ beegoInput.Context.Reset(httptest.NewRecorder(), r)
beegoInput.ParseFormOrMulitForm(1 << 20)
var id int
@@ -73,7 +76,9 @@ func TestParse(t *testing.T) {
func TestSubDomain(t *testing.T) {
r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
- beegoInput := NewInput(r)
+ beegoInput := NewInput()
+ beegoInput.Context = NewContext()
+ beegoInput.Context.Reset(httptest.NewRecorder(), r)
subdomain := beegoInput.SubDomains()
if subdomain != "www" {
@@ -81,13 +86,13 @@ func TestSubDomain(t *testing.T) {
}
r, _ = http.NewRequest("GET", "http://localhost/", nil)
- beegoInput.Request = r
+ beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains())
}
r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil)
- beegoInput.Request = r
+ beegoInput.Context.Request = r
if beegoInput.SubDomains() != "aa.bb" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
@@ -101,13 +106,13 @@ func TestSubDomain(t *testing.T) {
*/
r, _ = http.NewRequest("GET", "http://example.com/", nil)
- beegoInput.Request = r
+ beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil)
- beegoInput.Request = r
+ beegoInput.Context.Request = r
if beegoInput.SubDomains() != "aa.bb.cc.dd" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
diff --git a/context/output.go b/context/output.go
index 7edde552..2d756e27 100644
--- a/context/output.go
+++ b/context/output.go
@@ -16,8 +16,6 @@ package context
import (
"bytes"
- "compress/flate"
- "compress/gzip"
"encoding/json"
"encoding/xml"
"errors"
@@ -45,6 +43,12 @@ func NewOutput() *BeegoOutput {
return &BeegoOutput{}
}
+// Reset init BeegoOutput
+func (output *BeegoOutput) Reset(ctx *Context) {
+ output.Context = ctx
+ output.Status = 0
+}
+
// Header sets response header item string via given key.
func (output *BeegoOutput) Header(key, val string) {
output.Context.ResponseWriter.Header().Set(key, val)
@@ -54,30 +58,16 @@ func (output *BeegoOutput) Header(key, val string) {
// if EnableGzip, compress content string.
// it sends out response body directly.
func (output *BeegoOutput) Body(content []byte) {
- output_writer := output.Context.ResponseWriter.(io.Writer)
- if output.EnableGzip == true && output.Context.Input.Header("Accept-Encoding") != "" {
- splitted := strings.SplitN(output.Context.Input.Header("Accept-Encoding"), ",", -1)
- encodings := make([]string, len(splitted))
-
- for i, val := range splitted {
- encodings[i] = strings.TrimSpace(val)
- }
- for _, val := range encodings {
- if val == "gzip" {
- output.Header("Content-Encoding", "gzip")
- output_writer, _ = gzip.NewWriterLevel(output.Context.ResponseWriter, gzip.BestSpeed)
-
- break
- } else if val == "deflate" {
- output.Header("Content-Encoding", "deflate")
- output_writer, _ = flate.NewWriter(output.Context.ResponseWriter, flate.BestSpeed)
- break
- }
- }
+ var encoding string
+ var buf = &bytes.Buffer{}
+ if output.EnableGzip {
+ encoding = ParseEncoding(output.Context.Request)
+ }
+ if b, n, _ := WriteBody(encoding, buf, content); b {
+ output.Header("Content-Encoding", n)
} else {
output.Header("Content-Length", strconv.Itoa(len(content)))
}
-
// Write status code if it has been set manually
// Set it to 0 afterwards to prevent "multiple response.WriteHeader calls"
if output.Status != 0 {
@@ -85,13 +75,7 @@ func (output *BeegoOutput) Body(content []byte) {
output.Status = 0
}
- output_writer.Write(content)
- switch output_writer.(type) {
- case *gzip.Writer:
- output_writer.(*gzip.Writer).Close()
- case *flate.Writer:
- output_writer.(*flate.Writer).Close()
- }
+ io.Copy(output.Context.ResponseWriter, buf)
}
// Cookie sets cookie value via given key.
@@ -100,29 +84,25 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface
var b bytes.Buffer
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
- //fix cookie not work in IE
- if len(others) > 0 {
- switch v := others[0].(type) {
- case int:
- if v > 0 {
- fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v) * time.Second).UTC().Format(time.RFC1123), v)
- } else if v < 0 {
- fmt.Fprintf(&b, "; Max-Age=0")
- }
- case int64:
- if v > 0 {
- fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v) * time.Second).UTC().Format(time.RFC1123), v)
- } else if v < 0 {
- fmt.Fprintf(&b, "; Max-Age=0")
- }
- case int32:
- if v > 0 {
- fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v) * time.Second).UTC().Format(time.RFC1123), v)
- } else if v < 0 {
- fmt.Fprintf(&b, "; Max-Age=0")
- }
- }
- }
+ //fix cookie not work in IE
+ if len(others) > 0 {
+ var maxAge int64
+
+ switch v := others[0].(type) {
+ case int:
+ maxAge = int64(v)
+ case int32:
+ maxAge = int64(v)
+ case int64:
+ maxAge = v
+ }
+
+ if maxAge > 0 {
+ fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge)
+ } else {
+ fmt.Fprintf(&b, "; Max-Age=0")
+ }
+ }
// the settings below
// Path, Domain, Secure, HttpOnly
@@ -188,9 +168,9 @@ func sanitizeValue(v string) string {
return cookieValueSanitizer.Replace(v)
}
-// Json writes json to response body.
+// JSON writes json to response body.
// if coding is true, it converts utf-8 to \u0000 type.
-func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) error {
+func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error {
output.Header("Content-Type", "application/json; charset=utf-8")
var content []byte
var err error
@@ -204,14 +184,14 @@ func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) e
return err
}
if coding {
- content = []byte(stringsToJson(string(content)))
+ content = []byte(stringsToJSON(string(content)))
}
output.Body(content)
return nil
}
-// Jsonp writes jsonp to response body.
-func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
+// JSONP writes jsonp to response body.
+func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/javascript; charset=utf-8")
var content []byte
var err error
@@ -228,16 +208,16 @@ func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
if callback == "" {
return errors.New(`"callback" parameter required`)
}
- callback_content := bytes.NewBufferString(" " + template.JSEscapeString(callback))
- callback_content.WriteString("(")
- callback_content.Write(content)
- callback_content.WriteString(");\r\n")
- output.Body(callback_content.Bytes())
+ callbackContent := bytes.NewBufferString(" " + template.JSEscapeString(callback))
+ callbackContent.WriteString("(")
+ callbackContent.Write(content)
+ callbackContent.WriteString(");\r\n")
+ output.Body(callbackContent.Bytes())
return nil
}
-// Xml writes xml string to response body.
-func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error {
+// XML writes xml string to response body.
+func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/xml; charset=utf-8")
var content []byte
var err error
@@ -331,7 +311,7 @@ func (output *BeegoOutput) IsNotFound(status int) bool {
return output.Status == 404
}
-// IsClient returns boolean of this request client sends error data.
+// IsClientError returns boolean of this request client sends error data.
// HTTP 4xx means forbidden.
func (output *BeegoOutput) IsClientError(status int) bool {
return output.Status >= 400 && output.Status < 500
@@ -343,7 +323,7 @@ func (output *BeegoOutput) IsServerError(status int) bool {
return output.Status >= 500 && output.Status < 600
}
-func stringsToJson(str string) string {
+func stringsToJSON(str string) string {
rs := []rune(str)
jsons := ""
for _, r := range rs {
@@ -357,7 +337,7 @@ func stringsToJson(str string) string {
return jsons
}
-// Sessions sets session item value with given key.
+// Session sets session item value with given key.
func (output *BeegoOutput) Session(name interface{}, value interface{}) {
output.Context.Input.CruSession.Set(name, value)
}
diff --git a/controller.go b/controller.go
index a8b9e8d3..ab261a56 100644
--- a/controller.go
+++ b/controller.go
@@ -19,7 +19,6 @@ import (
"errors"
"html/template"
"io"
- "io/ioutil"
"mime/multipart"
"net/http"
"net/url"
@@ -34,18 +33,19 @@ import (
//commonly used mime-types
const (
- applicationJson = "application/json"
- applicationXml = "application/xml"
- textXml = "text/xml"
+ applicationJSON = "application/json"
+ applicationXML = "application/xml"
+ textXML = "text/xml"
)
var (
- // custom error when user stop request handler manually.
- USERSTOPRUN = errors.New("User stop run")
- GlobalControllerRouter map[string][]ControllerComments = make(map[string][]ControllerComments) //pkgpath+controller:comments
+ // ErrAbort custom error when user stop request handler manually.
+ ErrAbort = errors.New("User stop run")
+ // GlobalControllerRouter store comments with controller. pkgpath+controller:comments
+ GlobalControllerRouter = make(map[string][]ControllerComments)
)
-// store the comment for the controller method
+// ControllerComments store the comment for the controller method
type ControllerComments struct {
Method string
Router string
@@ -56,22 +56,31 @@ type ControllerComments struct {
// Controller defines some basic http request handler operations, such as
// http context, template and view, session and xsrf.
type Controller struct {
- Ctx *context.Context
- Data map[interface{}]interface{}
+ // context data
+ Ctx *context.Context
+ Data map[interface{}]interface{}
+
+ // route controller info
controllerName string
actionName string
- TplNames string
+ methodMapping map[string]func() //method:routertree
+ gotofunc string
+ AppController interface{}
+
+ // template data
+ TplName string
Layout string
LayoutSections map[string]string // the key is the section name and the value is the template name
TplExt string
- _xsrf_token string
- gotofunc string
- CruSession session.SessionStore
- XSRFExpire int
- AppController interface{}
EnableRender bool
- EnableXSRF bool
- methodMapping map[string]func() //method:routertree
+
+ // xsrf data
+ _xsrfToken string
+ XSRFExpire int
+ EnableXSRF bool
+
+ // session
+ CruSession session.Store
}
// ControllerInterface is an interface to uniform all controller handler.
@@ -87,8 +96,8 @@ type ControllerInterface interface {
Options()
Finish()
Render() error
- XsrfToken() string
- CheckXsrfCookie() bool
+ XSRFToken() string
+ CheckXSRFCookie() bool
HandlerFunc(fn string) bool
URLMapping()
}
@@ -96,7 +105,7 @@ type ControllerInterface interface {
// Init generates default values of controller operations.
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
c.Layout = ""
- c.TplNames = ""
+ c.TplName = ""
c.controllerName = controllerName
c.actionName = actionName
c.Ctx = ctx
@@ -104,19 +113,15 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin
c.AppController = app
c.EnableRender = true
c.EnableXSRF = true
- c.Data = ctx.Input.Data
+ c.Data = ctx.Input.Data()
c.methodMapping = make(map[string]func())
}
// Prepare runs after Init before request function execution.
-func (c *Controller) Prepare() {
-
-}
+func (c *Controller) Prepare() {}
// Finish runs after request function execution.
-func (c *Controller) Finish() {
-
-}
+func (c *Controller) Finish() {}
// Get adds a request function to handle GET request.
func (c *Controller) Get() {
@@ -153,20 +158,19 @@ func (c *Controller) Options() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
}
-// call function fn
+// HandlerFunc call function with the name
func (c *Controller) HandlerFunc(fnname string) bool {
if v, ok := c.methodMapping[fnname]; ok {
v()
return true
- } else {
- return false
}
+ return false
}
// URLMapping register the internal Controller router.
-func (c *Controller) URLMapping() {
-}
+func (c *Controller) URLMapping() {}
+// Mapping the method to function
func (c *Controller) Mapping(method string, fn func()) {
c.methodMapping[method] = fn
}
@@ -177,13 +181,11 @@ func (c *Controller) Render() error {
return nil
}
rb, err := c.RenderBytes()
-
if err != nil {
return err
- } else {
- c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
- c.Ctx.Output.Body(rb)
}
+ c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
+ c.Ctx.Output.Body(rb)
return nil
}
@@ -196,24 +198,33 @@ func (c *Controller) RenderString() (string, error) {
// RenderBytes returns the bytes of rendered template string. Do not send out response.
func (c *Controller) RenderBytes() ([]byte, error) {
//if the controller has set layout, then first get the tplname's content set the content to the layout
+ var buf bytes.Buffer
if c.Layout != "" {
- if c.TplNames == "" {
- c.TplNames = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
+ if c.TplName == "" {
+ c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
}
- if RunMode == "dev" {
- BuildTemplate(ViewsPath)
+
+ if BConfig.RunMode == DEV {
+ buildFiles := []string{c.TplName}
+ if c.LayoutSections != nil {
+ for _, sectionTpl := range c.LayoutSections {
+ if sectionTpl == "" {
+ continue
+ }
+ buildFiles = append(buildFiles, sectionTpl)
+ }
+ }
+ BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...)
}
- newbytes := bytes.NewBufferString("")
- if _, ok := BeeTemplates[c.TplNames]; !ok {
- panic("can't find templatefile in the path:" + c.TplNames)
+ if _, ok := BeeTemplates[c.TplName]; !ok {
+ panic("can't find templatefile in the path:" + c.TplName)
}
- err := BeeTemplates[c.TplNames].ExecuteTemplate(newbytes, c.TplNames, c.Data)
+ err := BeeTemplates[c.TplName].ExecuteTemplate(&buf, c.TplName, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
- tplcontent, _ := ioutil.ReadAll(newbytes)
- c.Data["LayoutContent"] = template.HTML(string(tplcontent))
+ c.Data["LayoutContent"] = template.HTML(buf.String())
if c.LayoutSections != nil {
for sectionName, sectionTpl := range c.LayoutSections {
@@ -222,44 +233,41 @@ func (c *Controller) RenderBytes() ([]byte, error) {
continue
}
- sectionBytes := bytes.NewBufferString("")
- err = BeeTemplates[sectionTpl].ExecuteTemplate(sectionBytes, sectionTpl, c.Data)
+ buf.Reset()
+ err = BeeTemplates[sectionTpl].ExecuteTemplate(&buf, sectionTpl, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
- sectionContent, _ := ioutil.ReadAll(sectionBytes)
- c.Data[sectionName] = template.HTML(string(sectionContent))
+ c.Data[sectionName] = template.HTML(buf.String())
}
}
- ibytes := bytes.NewBufferString("")
- err = BeeTemplates[c.Layout].ExecuteTemplate(ibytes, c.Layout, c.Data)
+ buf.Reset()
+ err = BeeTemplates[c.Layout].ExecuteTemplate(&buf, c.Layout, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
- icontent, _ := ioutil.ReadAll(ibytes)
- return icontent, nil
- } else {
- if c.TplNames == "" {
- c.TplNames = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
- }
- if RunMode == "dev" {
- BuildTemplate(ViewsPath)
- }
- ibytes := bytes.NewBufferString("")
- if _, ok := BeeTemplates[c.TplNames]; !ok {
- panic("can't find templatefile in the path:" + c.TplNames)
- }
- err := BeeTemplates[c.TplNames].ExecuteTemplate(ibytes, c.TplNames, c.Data)
- if err != nil {
- Trace("template Execute err:", err)
- return nil, err
- }
- icontent, _ := ioutil.ReadAll(ibytes)
- return icontent, nil
+ return buf.Bytes(), nil
}
+
+ if c.TplName == "" {
+ c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
+ }
+ if BConfig.RunMode == DEV {
+ BuildTemplate(BConfig.WebConfig.ViewsPath, c.TplName)
+ }
+ if _, ok := BeeTemplates[c.TplName]; !ok {
+ panic("can't find templatefile in the path:" + c.TplName)
+ }
+ buf.Reset()
+ err := BeeTemplates[c.TplName].ExecuteTemplate(&buf, c.TplName, c.Data)
+ if err != nil {
+ Trace("template Execute err:", err)
+ return nil, err
+ }
+ return buf.Bytes(), nil
}
// Redirect sends the redirection response to url with status code.
@@ -267,7 +275,7 @@ func (c *Controller) Redirect(url string, code int) {
c.Ctx.Redirect(code, url)
}
-// Aborts stops controller handler and show the error data if code is defined in ErrorMap or code string.
+// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string.
func (c *Controller) Abort(code string) {
status, err := strconv.Atoi(code)
if err != nil {
@@ -285,74 +293,69 @@ func (c *Controller) CustomAbort(status int, body string) {
}
// last panic user string
c.Ctx.ResponseWriter.Write([]byte(body))
- panic(USERSTOPRUN)
+ panic(ErrAbort)
}
// StopRun makes panic of USERSTOPRUN error and go to recover function if defined.
func (c *Controller) StopRun() {
- panic(USERSTOPRUN)
+ panic(ErrAbort)
}
-// UrlFor does another controller handler in this request function.
+// URLFor does another controller handler in this request function.
// it goes to this controller method if endpoint is not clear.
-func (c *Controller) UrlFor(endpoint string, values ...interface{}) string {
- if len(endpoint) <= 0 {
+func (c *Controller) URLFor(endpoint string, values ...interface{}) string {
+ if len(endpoint) == 0 {
return ""
}
if endpoint[0] == '.' {
- return UrlFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...)
- } else {
- return UrlFor(endpoint, values...)
+ return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...)
}
+ return URLFor(endpoint, values...)
}
-// ServeJson sends a json response with encoding charset.
-func (c *Controller) ServeJson(encoding ...bool) {
- var hasIndent bool
- var hasencoding bool
- if RunMode == "prod" {
+// ServeJSON sends a json response with encoding charset.
+func (c *Controller) ServeJSON(encoding ...bool) {
+ var (
+ hasIndent = true
+ hasEncoding = false
+ )
+ if BConfig.RunMode == PROD {
hasIndent = false
- } else {
- hasIndent = true
}
if len(encoding) > 0 && encoding[0] == true {
- hasencoding = true
+ hasEncoding = true
}
- c.Ctx.Output.Json(c.Data["json"], hasIndent, hasencoding)
+ c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding)
}
-// ServeJsonp sends a jsonp response.
-func (c *Controller) ServeJsonp() {
- var hasIndent bool
- if RunMode == "prod" {
+// ServeJSONP sends a jsonp response.
+func (c *Controller) ServeJSONP() {
+ hasIndent := true
+ if BConfig.RunMode == PROD {
hasIndent = false
- } else {
- hasIndent = true
}
- c.Ctx.Output.Jsonp(c.Data["jsonp"], hasIndent)
+ c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent)
}
-// ServeXml sends xml response.
-func (c *Controller) ServeXml() {
- var hasIndent bool
- if RunMode == "prod" {
+// ServeXML sends xml response.
+func (c *Controller) ServeXML() {
+ hasIndent := true
+ if BConfig.RunMode == PROD {
hasIndent = false
- } else {
- hasIndent = true
}
- c.Ctx.Output.Xml(c.Data["xml"], hasIndent)
+ c.Ctx.Output.XML(c.Data["xml"], hasIndent)
}
// ServeFormatted serve Xml OR Json, depending on the value of the Accept header
func (c *Controller) ServeFormatted() {
accept := c.Ctx.Input.Header("Accept")
switch accept {
- case applicationJson:
- c.ServeJson()
- case applicationXml, textXml:
- c.ServeXml()
+ case applicationJSON:
+ c.ServeJSON()
+ case applicationXML, textXML:
+ c.ServeXML()
default:
- c.ServeJson()
+ c.ServeJSON()
}
}
@@ -371,16 +374,13 @@ func (c *Controller) ParseForm(obj interface{}) error {
// GetString returns the input value by key string or the default value while it's present and input is blank
func (c *Controller) GetString(key string, def ...string) string {
- var defv string
- if len(def) > 0 {
- defv = def[0]
- }
-
if v := c.Ctx.Input.Query(key); v != "" {
return v
- } else {
- return defv
}
+ if len(def) > 0 {
+ return def[0]
+ }
+ return ""
}
// GetStrings returns the input string slice by key string or the default value while it's present and input is blank
@@ -391,106 +391,81 @@ func (c *Controller) GetStrings(key string, def ...[]string) []string {
defv = def[0]
}
- f := c.Input()
- if f == nil {
+ if f := c.Input(); f == nil {
return defv
+ } else {
+ if vs := f[key]; len(vs) > 0 {
+ return vs
+ }
}
- vs := f[key]
- if len(vs) > 0 {
- return vs
- } else {
- return defv
- }
+ return defv
}
// GetInt returns input as an int or the default value while it's present and input is blank
func (c *Controller) GetInt(key string, def ...int) (int, error) {
- if strv := c.Ctx.Input.Query(key); strv != "" {
- return strconv.Atoi(strv)
- } else if len(def) > 0 {
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
return def[0], nil
- } else {
- return strconv.Atoi(strv)
}
+ return strconv.Atoi(strv)
}
// GetInt8 return input as an int8 or the default value while it's present and input is blank
func (c *Controller) GetInt8(key string, def ...int8) (int8, error) {
- if strv := c.Ctx.Input.Query(key); strv != "" {
- i64, err := strconv.ParseInt(strv, 10, 8)
- i8 := int8(i64)
- return i8, err
- } else if len(def) > 0 {
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
return def[0], nil
- } else {
- i64, err := strconv.ParseInt(strv, 10, 8)
- i8 := int8(i64)
- return i8, err
}
+ i64, err := strconv.ParseInt(strv, 10, 8)
+ return int8(i64), err
}
// GetInt16 returns input as an int16 or the default value while it's present and input is blank
func (c *Controller) GetInt16(key string, def ...int16) (int16, error) {
- if strv := c.Ctx.Input.Query(key); strv != "" {
- i64, err := strconv.ParseInt(strv, 10, 16)
- i16 := int16(i64)
- return i16, err
- } else if len(def) > 0 {
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
return def[0], nil
- } else {
- i64, err := strconv.ParseInt(strv, 10, 16)
- i16 := int16(i64)
- return i16, err
}
+ i64, err := strconv.ParseInt(strv, 10, 16)
+ return int16(i64), err
}
// GetInt32 returns input as an int32 or the default value while it's present and input is blank
func (c *Controller) GetInt32(key string, def ...int32) (int32, error) {
- if strv := c.Ctx.Input.Query(key); strv != "" {
- i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 32)
- i32 := int32(i64)
- return i32, err
- } else if len(def) > 0 {
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
return def[0], nil
- } else {
- i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 32)
- i32 := int32(i64)
- return i32, err
}
+ i64, err := strconv.ParseInt(strv, 10, 32)
+ return int32(i64), err
}
// GetInt64 returns input value as int64 or the default value while it's present and input is blank.
func (c *Controller) GetInt64(key string, def ...int64) (int64, error) {
- if strv := c.Ctx.Input.Query(key); strv != "" {
- return strconv.ParseInt(strv, 10, 64)
- } else if len(def) > 0 {
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
return def[0], nil
- } else {
- return strconv.ParseInt(strv, 10, 64)
}
+ return strconv.ParseInt(strv, 10, 64)
}
// GetBool returns input value as bool or the default value while it's present and input is blank.
func (c *Controller) GetBool(key string, def ...bool) (bool, error) {
- if strv := c.Ctx.Input.Query(key); strv != "" {
- return strconv.ParseBool(strv)
- } else if len(def) > 0 {
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
return def[0], nil
- } else {
- return strconv.ParseBool(strv)
}
+ return strconv.ParseBool(strv)
}
// GetFloat returns input value as float64 or the default value while it's present and input is blank.
func (c *Controller) GetFloat(key string, def ...float64) (float64, error) {
- if strv := c.Ctx.Input.Query(key); strv != "" {
- return strconv.ParseFloat(strv, 64)
- } else if len(def) > 0 {
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
return def[0], nil
- } else {
- return strconv.ParseFloat(strv, 64)
}
+ return strconv.ParseFloat(strv, 64)
}
// GetFile returns the file data in file upload field named as key.
@@ -527,8 +502,7 @@ func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader,
// }
// }
func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) {
- files, ok := c.Ctx.Request.MultipartForm.File[key]
- if ok {
+ if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok {
return files, nil
}
return nil, http.ErrMissingFile
@@ -552,7 +526,7 @@ func (c *Controller) SaveToFile(fromfile, tofile string) error {
}
// StartSession starts session and load old session data info this controller.
-func (c *Controller) StartSession() session.SessionStore {
+func (c *Controller) StartSession() session.Store {
if c.CruSession == nil {
c.CruSession = c.Ctx.Input.CruSession
}
@@ -575,7 +549,7 @@ func (c *Controller) GetSession(name interface{}) interface{} {
return c.CruSession.Get(name)
}
-// SetSession removes value from session.
+// DelSession removes value from session.
func (c *Controller) DelSession(name interface{}) {
if c.CruSession == nil {
c.StartSession()
@@ -589,7 +563,7 @@ func (c *Controller) SessionRegenerateID() {
if c.CruSession != nil {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
}
- c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
+ c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession
}
@@ -614,37 +588,35 @@ func (c *Controller) SetSecureCookie(Secret, name, value string, others ...inter
c.Ctx.SetSecureCookie(Secret, name, value, others...)
}
-// XsrfToken creates a xsrf token string and returns.
-func (c *Controller) XsrfToken() string {
- if c._xsrf_token == "" {
- var expire int64
+// XSRFToken creates a CSRF token string and returns.
+func (c *Controller) XSRFToken() string {
+ if c._xsrfToken == "" {
+ expire := int64(BConfig.WebConfig.XSRFExpire)
if c.XSRFExpire > 0 {
expire = int64(c.XSRFExpire)
- } else {
- expire = int64(XSRFExpire)
}
- c._xsrf_token = c.Ctx.XsrfToken(XSRFKEY, expire)
+ c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire)
}
- return c._xsrf_token
+ return c._xsrfToken
}
-// CheckXsrfCookie checks xsrf token in this request is valid or not.
+// CheckXSRFCookie checks xsrf token in this request is valid or not.
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
// or in form field value named as "_xsrf".
-func (c *Controller) CheckXsrfCookie() bool {
+func (c *Controller) CheckXSRFCookie() bool {
if !c.EnableXSRF {
return true
}
- return c.Ctx.CheckXsrfCookie()
+ return c.Ctx.CheckXSRFCookie()
}
-// XsrfFormHtml writes an input field contains xsrf token value.
-func (c *Controller) XsrfFormHtml() string {
- return ""
+// XSRFFormHTML writes an input field contains xsrf token value.
+func (c *Controller) XSRFFormHTML() string {
+ return ``
}
// GetControllerAndAction gets the executing controller name and action name.
-func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
+func (c *Controller) GetControllerAndAction() (string, string) {
return c.controllerName, c.actionName
}
diff --git a/controller_test.go b/controller_test.go
index 15938cdc..51d3a5b7 100644
--- a/controller_test.go
+++ b/controller_test.go
@@ -15,61 +15,63 @@
package beego
import (
- "fmt"
+ "testing"
+
"github.com/astaxie/beego/context"
)
-func ExampleGetInt() {
-
- i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
+func TestGetInt(t *testing.T) {
+ i := context.NewInput()
+ i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
-
val, _ := ctrlr.GetInt("age")
- fmt.Printf("%T", val)
- //Output: int
+ if val != 40 {
+ t.Errorf("TestGetInt expect 40,get %T,%v", val, val)
+ }
}
-func ExampleGetInt8() {
-
- i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
+func TestGetInt8(t *testing.T) {
+ i := context.NewInput()
+ i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
-
val, _ := ctrlr.GetInt8("age")
- fmt.Printf("%T", val)
+ if val != 40 {
+ t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val)
+ }
//Output: int8
}
-func ExampleGetInt16() {
-
- i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
+func TestGetInt16(t *testing.T) {
+ i := context.NewInput()
+ i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
-
val, _ := ctrlr.GetInt16("age")
- fmt.Printf("%T", val)
- //Output: int16
+ if val != 40 {
+ t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val)
+ }
}
-func ExampleGetInt32() {
-
- i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
+func TestGetInt32(t *testing.T) {
+ i := context.NewInput()
+ i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
-
val, _ := ctrlr.GetInt32("age")
- fmt.Printf("%T", val)
- //Output: int32
+ if val != 40 {
+ t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val)
+ }
}
-func ExampleGetInt64() {
-
- i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
+func TestGetInt64(t *testing.T) {
+ i := context.NewInput()
+ i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
-
val, _ := ctrlr.GetInt64("age")
- fmt.Printf("%T", val)
- //Output: int64
+ if val != 40 {
+ t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val)
+ }
}
diff --git a/doc.go b/doc.go
new file mode 100644
index 00000000..4be305b3
--- /dev/null
+++ b/doc.go
@@ -0,0 +1,17 @@
+/*
+Package beego provide a MVC framework
+beego: an open-source, high-performance, modular, full-stack web framework
+
+It is used for rapid development of RESTful APIs, web apps and backend services in Go.
+beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding.
+
+ package main
+ import "github.com/astaxie/beego"
+
+ func main() {
+ beego.Run()
+ }
+
+more infomation: http://beego.me
+*/
+package beego
diff --git a/docs.go b/docs.go
index aaad205e..72532876 100644
--- a/docs.go
+++ b/docs.go
@@ -15,37 +15,24 @@
package beego
import (
- "encoding/json"
-
"github.com/astaxie/beego/context"
)
-var GlobalDocApi map[string]interface{}
-
-func init() {
- if EnableDocs {
- GlobalDocApi = make(map[string]interface{})
- }
-}
+// GlobalDocAPI store the swagger api documents
+var GlobalDocAPI = make(map[string]interface{})
func serverDocs(ctx *context.Context) {
var obj interface{}
if splat := ctx.Input.Param(":splat"); splat == "" {
- obj = GlobalDocApi["Root"]
+ obj = GlobalDocAPI["Root"]
} else {
- if v, ok := GlobalDocApi[splat]; ok {
+ if v, ok := GlobalDocAPI[splat]; ok {
obj = v
}
}
if obj != nil {
- bt, err := json.Marshal(obj)
- if err != nil {
- ctx.Output.SetStatus(504)
- return
- }
- ctx.Output.Header("Content-Type", "application/json;charset=UTF-8")
ctx.Output.Header("Access-Control-Allow-Origin", "*")
- ctx.Output.Body(bt)
+ ctx.Output.JSON(obj, false, false)
return
}
ctx.Output.SetStatus(404)
diff --git a/error.go b/error.go
index 99a1fcf3..af57b7c7 100644
--- a/error.go
+++ b/error.go
@@ -82,16 +82,17 @@ var tpl = `
`
// render default application error page with error and stack string.
-func showErr(err interface{}, ctx *context.Context, Stack string) {
+func showErr(err interface{}, ctx *context.Context, stack string) {
t, _ := template.New("beegoerrortemp").Parse(tpl)
- data := make(map[string]string)
- data["AppError"] = AppName + ":" + fmt.Sprint(err)
- data["RequestMethod"] = ctx.Input.Method()
- data["RequestURL"] = ctx.Input.Uri()
- data["RemoteAddr"] = ctx.Input.IP()
- data["Stack"] = Stack
- data["BeegoVersion"] = VERSION
- data["GoVersion"] = runtime.Version()
+ data := map[string]string{
+ "AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err),
+ "RequestMethod": ctx.Input.Method(),
+ "RequestURL": ctx.Input.URI(),
+ "RemoteAddr": ctx.Input.IP(),
+ "Stack": stack,
+ "BeegoVersion": VERSION,
+ "GoVersion": runtime.Version(),
+ }
ctx.ResponseWriter.WriteHeader(500)
t.Execute(ctx.ResponseWriter, data)
}
@@ -204,47 +205,48 @@ type errorInfo struct {
}
// map of http handlers for each error string.
-var ErrorMaps map[string]*errorInfo
-
-func init() {
- ErrorMaps = make(map[string]*errorInfo)
-}
+// there is 10 kinds default error(40x and 50x)
+var ErrorMaps = make(map[string]*errorInfo, 10)
// show 401 unauthorized error.
func unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Unauthorized"
+ data := map[string]interface{}{
+ "Title": http.StatusText(401),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested can't be authorized." +
" Perhaps you are here because:" +
"
" +
" The credentials you supplied are incorrect" +
" There are errors in the website address" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 402 Payment Required
func paymentRequired(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Payment Required"
+ data := map[string]interface{}{
+ "Title": http.StatusText(402),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested Payment Required." +
" Perhaps you are here because:" +
"
" +
" The credentials you supplied are incorrect" +
" There are errors in the website address" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 403 forbidden error.
func forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Forbidden"
+ data := map[string]interface{}{
+ "Title": http.StatusText(403),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is forbidden." +
" Perhaps you are here because:" +
"
" +
@@ -252,15 +254,16 @@ func forbidden(rw http.ResponseWriter, r *http.Request) {
" The site may be disabled" +
" You need to log in" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 404 notfound error.
func notFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Page Not Found"
+ data := map[string]interface{}{
+ "Title": http.StatusText(404),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested has flown the coop." +
" Perhaps you are here because:" +
"
" +
@@ -269,191 +272,158 @@ func notFound(rw http.ResponseWriter, r *http.Request) {
" You were looking for your puppy and got lost" +
" You like 404 pages" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 405 Method Not Allowed
func methodNotAllowed(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Method Not Allowed"
+ data := map[string]interface{}{
+ "Title": http.StatusText(405),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The method you have requested Not Allowed." +
" Perhaps you are here because:" +
"
" +
" The method specified in the Request-Line is not allowed for the resource identified by the Request-URI" +
" The response MUST include an Allow header containing a list of valid methods for the requested resource." +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 500 internal server error.
func internalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Internal Server Error"
+ data := map[string]interface{}{
+ "Title": http.StatusText(500),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is down right now." +
"
" +
" Please try again later and report the error to the website administrator" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 501 Not Implemented.
func notImplemented(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Not Implemented"
+ data := map[string]interface{}{
+ "Title": http.StatusText(504),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is Not Implemented." +
"
" +
" Please try again later and report the error to the website administrator" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 502 Bad Gateway.
func badGateway(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Bad Gateway"
+ data := map[string]interface{}{
+ "Title": http.StatusText(502),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is down right now." +
"
" +
" The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request." +
" Please try again later and report the error to the website administrator" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 503 service unavailable error.
func serviceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Service Unavailable"
+ data := map[string]interface{}{
+ "Title": http.StatusText(503),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is unavailable." +
" Perhaps you are here because:" +
"
" +
"
The page is overloaded" +
" Please try again later." +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 504 Gateway Timeout.
func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Gateway Timeout"
+ data := map[string]interface{}{
+ "Title": http.StatusText(504),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is unavailable." +
" Perhaps you are here because:" +
"
" +
"
The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI." +
" Please try again later." +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
-// register default error http handlers, 404,401,403,500 and 503.
-func registerDefaultErrorHandler() {
- if _, ok := ErrorMaps["401"]; !ok {
- Errorhandler("401", unauthorized)
- }
-
- if _, ok := ErrorMaps["402"]; !ok {
- Errorhandler("402", paymentRequired)
- }
-
- if _, ok := ErrorMaps["403"]; !ok {
- Errorhandler("403", forbidden)
- }
-
- if _, ok := ErrorMaps["404"]; !ok {
- Errorhandler("404", notFound)
- }
-
- if _, ok := ErrorMaps["405"]; !ok {
- Errorhandler("405", methodNotAllowed)
- }
-
- if _, ok := ErrorMaps["500"]; !ok {
- Errorhandler("500", internalServerError)
- }
- if _, ok := ErrorMaps["501"]; !ok {
- Errorhandler("501", notImplemented)
- }
- if _, ok := ErrorMaps["502"]; !ok {
- Errorhandler("502", badGateway)
- }
-
- if _, ok := ErrorMaps["503"]; !ok {
- Errorhandler("503", serviceUnavailable)
- }
-
- if _, ok := ErrorMaps["504"]; !ok {
- Errorhandler("504", gatewayTimeout)
- }
-}
-
// ErrorHandler registers http.HandlerFunc to each http err code string.
// usage:
// beego.ErrorHandler("404",NotFound)
// beego.ErrorHandler("500",InternalServerError)
-func Errorhandler(code string, h http.HandlerFunc) *App {
- errinfo := &errorInfo{}
- errinfo.errorType = errorTypeHandler
- errinfo.handler = h
- errinfo.method = code
- ErrorMaps[code] = errinfo
+func ErrorHandler(code string, h http.HandlerFunc) *App {
+ ErrorMaps[code] = &errorInfo{
+ errorType: errorTypeHandler,
+ handler: h,
+ method: code,
+ }
return BeeApp
}
// ErrorController registers ControllerInterface to each http err code string.
// usage:
-// beego.ErrorHandler(&controllers.ErrorController{})
+// beego.ErrorController(&controllers.ErrorController{})
func ErrorController(c ControllerInterface) *App {
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
for i := 0; i < rt.NumMethod(); i++ {
- if !utils.InSlice(rt.Method(i).Name, exceptMethod) && strings.HasPrefix(rt.Method(i).Name, "Error") {
- errinfo := &errorInfo{}
- errinfo.errorType = errorTypeController
- errinfo.controllerType = ct
- errinfo.method = rt.Method(i).Name
- errname := strings.TrimPrefix(rt.Method(i).Name, "Error")
- ErrorMaps[errname] = errinfo
+ methodName := rt.Method(i).Name
+ if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") {
+ errName := strings.TrimPrefix(methodName, "Error")
+ ErrorMaps[errName] = &errorInfo{
+ errorType: errorTypeController,
+ controllerType: ct,
+ method: methodName,
+ }
}
}
return BeeApp
}
// show error string as simple text message.
-// if error string is empty, show 500 error as default.
-func exception(errcode string, ctx *context.Context) {
- code, err := strconv.Atoi(errcode)
- if err != nil {
- code = 503
+// if error string is empty, show 503 or 500 error as default.
+func exception(errCode string, ctx *context.Context) {
+ atoi := func(code string) int {
+ v, err := strconv.Atoi(code)
+ if err == nil {
+ return v
+ }
+ return 503
}
- if h, ok := ErrorMaps[errcode]; ok {
- executeError(h, ctx, code)
- return
- } else if h, ok := ErrorMaps["503"]; ok {
- executeError(h, ctx, code)
- return
- } else {
- ctx.ResponseWriter.WriteHeader(code)
- ctx.WriteString(errcode)
+
+ for _, ec := range []string{errCode, "503", "500"} {
+ if h, ok := ErrorMaps[ec]; ok {
+ executeError(h, ctx, atoi(ec))
+ return
+ }
}
+ //if 50x error has been removed from errorMap
+ ctx.ResponseWriter.WriteHeader(atoi(errCode))
+ ctx.WriteString(errCode)
}
func executeError(err *errorInfo, ctx *context.Context, code int) {
if err.errorType == errorTypeHandler {
- ctx.ResponseWriter.WriteHeader(code)
err.handler(ctx.ResponseWriter, ctx.Request)
return
}
@@ -473,12 +443,11 @@ func executeError(err *errorInfo, ctx *context.Context, code int) {
execController.URLMapping()
- in := make([]reflect.Value, 0)
method := vc.MethodByName(err.method)
- method.Call(in)
+ method.Call([]reflect.Value{})
//render template
- if AutoRender {
+ if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil {
panic(err)
}
diff --git a/example/beeapi/conf/app.conf b/example/beeapi/conf/app.conf
deleted file mode 100644
index fd1681a2..00000000
--- a/example/beeapi/conf/app.conf
+++ /dev/null
@@ -1,5 +0,0 @@
-appname = beeapi
-httpport = 8080
-runmode = dev
-autorender = false
-copyrequestbody = true
diff --git a/example/beeapi/controllers/default.go b/example/beeapi/controllers/default.go
deleted file mode 100644
index a2184075..00000000
--- a/example/beeapi/controllers/default.go
+++ /dev/null
@@ -1,63 +0,0 @@
-// Beego (http://beego.me/)
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-// @link http://github.com/astaxie/beego for the canonical source repository
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-// @authors astaxie
-
-package controllers
-
-import (
- "encoding/json"
-
- "github.com/astaxie/beego"
- "github.com/astaxie/beego/example/beeapi/models"
-)
-
-type ObjectController struct {
- beego.Controller
-}
-
-func (o *ObjectController) Post() {
- var ob models.Object
- json.Unmarshal(o.Ctx.Input.RequestBody, &ob)
- objectid := models.AddOne(ob)
- o.Data["json"] = map[string]string{"ObjectId": objectid}
- o.ServeJson()
-}
-
-func (o *ObjectController) Get() {
- objectId := o.Ctx.Input.Params[":objectId"]
- if objectId != "" {
- ob, err := models.GetOne(objectId)
- if err != nil {
- o.Data["json"] = err
- } else {
- o.Data["json"] = ob
- }
- } else {
- obs := models.GetAll()
- o.Data["json"] = obs
- }
- o.ServeJson()
-}
-
-func (o *ObjectController) Put() {
- objectId := o.Ctx.Input.Params[":objectId"]
- var ob models.Object
- json.Unmarshal(o.Ctx.Input.RequestBody, &ob)
-
- err := models.Update(objectId, ob.Score)
- if err != nil {
- o.Data["json"] = err
- } else {
- o.Data["json"] = "update success!"
- }
- o.ServeJson()
-}
-
-func (o *ObjectController) Delete() {
- objectId := o.Ctx.Input.Params[":objectId"]
- models.Delete(objectId)
- o.Data["json"] = "delete success!"
- o.ServeJson()
-}
diff --git a/example/beeapi/main.go b/example/beeapi/main.go
deleted file mode 100644
index c1250e03..00000000
--- a/example/beeapi/main.go
+++ /dev/null
@@ -1,30 +0,0 @@
-// Beego (http://beego.me/)
-
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-
-// @link http://github.com/astaxie/beego for the canonical source repository
-
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-
-// @authors astaxie
-
-package main
-
-import (
- "github.com/astaxie/beego"
- "github.com/astaxie/beego/example/beeapi/controllers"
-)
-
-// Objects
-
-// URL HTTP Verb Functionality
-// /object POST Creating Objects
-// /object/ GET Retrieving Objects
-// /object/ PUT Updating Objects
-// /object GET Queries
-// /object/ DELETE Deleting Objects
-
-func main() {
- beego.RESTRouter("/object", &controllers.ObjectController{})
- beego.Run()
-}
diff --git a/example/beeapi/models/object.go b/example/beeapi/models/object.go
deleted file mode 100644
index 46109c50..00000000
--- a/example/beeapi/models/object.go
+++ /dev/null
@@ -1,58 +0,0 @@
-// Beego (http://beego.me/)
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-// @link http://github.com/astaxie/beego for the canonical source repository
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-// @authors astaxie
-
-package models
-
-import (
- "errors"
- "strconv"
- "time"
-)
-
-var (
- Objects map[string]*Object
-)
-
-type Object struct {
- ObjectId string
- Score int64
- PlayerName string
-}
-
-func init() {
- Objects = make(map[string]*Object)
- Objects["hjkhsbnmn123"] = &Object{"hjkhsbnmn123", 100, "astaxie"}
- Objects["mjjkxsxsaa23"] = &Object{"mjjkxsxsaa23", 101, "someone"}
-}
-
-func AddOne(object Object) (ObjectId string) {
- object.ObjectId = "astaxie" + strconv.FormatInt(time.Now().UnixNano(), 10)
- Objects[object.ObjectId] = &object
- return object.ObjectId
-}
-
-func GetOne(ObjectId string) (object *Object, err error) {
- if v, ok := Objects[ObjectId]; ok {
- return v, nil
- }
- return nil, errors.New("ObjectId Not Exist")
-}
-
-func GetAll() map[string]*Object {
- return Objects
-}
-
-func Update(ObjectId string, Score int64) (err error) {
- if v, ok := Objects[ObjectId]; ok {
- v.Score = Score
- return nil
- }
- return errors.New("ObjectId Not Exist")
-}
-
-func Delete(ObjectId string) {
- delete(Objects, ObjectId)
-}
diff --git a/example/chat/conf/app.conf b/example/chat/conf/app.conf
deleted file mode 100644
index 9d5a823f..00000000
--- a/example/chat/conf/app.conf
+++ /dev/null
@@ -1,3 +0,0 @@
-appname = chat
-httpport = 8080
-runmode = dev
diff --git a/example/chat/controllers/default.go b/example/chat/controllers/default.go
deleted file mode 100644
index 95ef8a9b..00000000
--- a/example/chat/controllers/default.go
+++ /dev/null
@@ -1,20 +0,0 @@
-// Beego (http://beego.me/)
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-// @link http://github.com/astaxie/beego for the canonical source repository
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-// @authors Unknwon
-
-package controllers
-
-import (
- "github.com/astaxie/beego"
-)
-
-type MainController struct {
- beego.Controller
-}
-
-func (m *MainController) Get() {
- m.Data["host"] = m.Ctx.Request.Host
- m.TplNames = "index.tpl"
-}
diff --git a/example/chat/controllers/ws.go b/example/chat/controllers/ws.go
deleted file mode 100644
index fc3917b3..00000000
--- a/example/chat/controllers/ws.go
+++ /dev/null
@@ -1,181 +0,0 @@
-// Beego (http://beego.me/)
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-// @link http://github.com/astaxie/beego for the canonical source repository
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-// @authors Unknwon
-
-package controllers
-
-import (
- "io/ioutil"
- "math/rand"
- "net/http"
- "time"
-
- "github.com/astaxie/beego"
- "github.com/gorilla/websocket"
-)
-
-const (
- // Time allowed to write a message to the client.
- writeWait = 10 * time.Second
-
- // Time allowed to read the next message from the client.
- readWait = 60 * time.Second
-
- // Send pings to client with this period. Must be less than readWait.
- pingPeriod = (readWait * 9) / 10
-
- // Maximum message size allowed from client.
- maxMessageSize = 512
-)
-
-func init() {
- rand.Seed(time.Now().UTC().UnixNano())
- go h.run()
-}
-
-// connection is an middleman between the websocket connection and the hub.
-type connection struct {
- username string
-
- // The websocket connection.
- ws *websocket.Conn
-
- // Buffered channel of outbound messages.
- send chan []byte
-}
-
-// readPump pumps messages from the websocket connection to the hub.
-func (c *connection) readPump() {
- defer func() {
- h.unregister <- c
- c.ws.Close()
- }()
- c.ws.SetReadLimit(maxMessageSize)
- c.ws.SetReadDeadline(time.Now().Add(readWait))
- for {
- op, r, err := c.ws.NextReader()
- if err != nil {
- break
- }
- switch op {
- case websocket.PongMessage:
- c.ws.SetReadDeadline(time.Now().Add(readWait))
- case websocket.TextMessage:
- message, err := ioutil.ReadAll(r)
- if err != nil {
- break
- }
- h.broadcast <- []byte(c.username + "_" + time.Now().Format("15:04:05") + ":" + string(message))
- }
- }
-}
-
-// write writes a message with the given opCode and payload.
-func (c *connection) write(opCode int, payload []byte) error {
- c.ws.SetWriteDeadline(time.Now().Add(writeWait))
- return c.ws.WriteMessage(opCode, payload)
-}
-
-// writePump pumps messages from the hub to the websocket connection.
-func (c *connection) writePump() {
- ticker := time.NewTicker(pingPeriod)
- defer func() {
- ticker.Stop()
- c.ws.Close()
- }()
- for {
- select {
- case message, ok := <-c.send:
- if !ok {
- c.write(websocket.CloseMessage, []byte{})
- return
- }
- if err := c.write(websocket.TextMessage, message); err != nil {
- return
- }
- case <-ticker.C:
- if err := c.write(websocket.PingMessage, []byte{}); err != nil {
- return
- }
- }
- }
-}
-
-type hub struct {
- // Registered connections.
- connections map[*connection]bool
-
- // Inbound messages from the connections.
- broadcast chan []byte
-
- // Register requests from the connections.
- register chan *connection
-
- // Unregister requests from connections.
- unregister chan *connection
-}
-
-var h = &hub{
- broadcast: make(chan []byte, maxMessageSize),
- register: make(chan *connection, 1),
- unregister: make(chan *connection, 1),
- connections: make(map[*connection]bool),
-}
-
-func (h *hub) run() {
- for {
- select {
- case c := <-h.register:
- h.connections[c] = true
- case c := <-h.unregister:
- delete(h.connections, c)
- close(c.send)
- case m := <-h.broadcast:
- for c := range h.connections {
- select {
- case c.send <- m:
- default:
- close(c.send)
- delete(h.connections, c)
- }
- }
- }
- }
-}
-
-type WSController struct {
- beego.Controller
-}
-
-var upgrader = websocket.Upgrader{
- ReadBufferSize: 1024,
- WriteBufferSize: 1024,
-}
-
-func (w *WSController) Get() {
- ws, err := upgrader.Upgrade(w.Ctx.ResponseWriter, w.Ctx.Request, nil)
- if _, ok := err.(websocket.HandshakeError); ok {
- http.Error(w.Ctx.ResponseWriter, "Not a websocket handshake", 400)
- return
- } else if err != nil {
- return
- }
- c := &connection{send: make(chan []byte, 256), ws: ws, username: randomString(10)}
- h.register <- c
- go c.writePump()
- c.readPump()
-}
-
-func randomString(l int) string {
- bytes := make([]byte, l)
- for i := 0; i < l; i++ {
- bytes[i] = byte(randInt(65, 90))
- }
- return string(bytes)
-}
-
-func randInt(min int, max int) int {
- return min + rand.Intn(max-min)
-}
diff --git a/example/chat/main.go b/example/chat/main.go
deleted file mode 100644
index be055b25..00000000
--- a/example/chat/main.go
+++ /dev/null
@@ -1,17 +0,0 @@
-// Beego (http://beego.me/)
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-// @link http://github.com/astaxie/beego for the canonical source repository
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-// @authors Unknwon
-package main
-
-import (
- "github.com/astaxie/beego"
- "github.com/astaxie/beego/example/chat/controllers"
-)
-
-func main() {
- beego.Router("/", &controllers.MainController{})
- beego.Router("/ws", &controllers.WSController{})
- beego.Run()
-}
diff --git a/example/chat/views/index.tpl b/example/chat/views/index.tpl
deleted file mode 100644
index 3a9d8838..00000000
--- a/example/chat/views/index.tpl
+++ /dev/null
@@ -1,92 +0,0 @@
-
-
-
-Chat Example
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/filter.go b/filter.go
index f673ab66..863223f7 100644
--- a/filter.go
+++ b/filter.go
@@ -32,14 +32,12 @@ type FilterRouter struct {
// ValidRouter checks if the current request is matched by this filter.
// If the request is matched, the values of the URL parameters defined
// by the filter pattern are also returned.
-func (f *FilterRouter) ValidRouter(url string) (bool, map[string]string) {
- isok, params := f.tree.Match(url)
- if isok == nil {
- return false, nil
- }
- if isok, ok := isok.(bool); ok {
- return isok, params
- } else {
- return false, nil
+func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool {
+ isOk := f.tree.Match(url, ctx)
+ if isOk != nil {
+ if b, ok := isOk.(bool); ok {
+ return b
+ }
}
+ return false
}
diff --git a/filter_test.go b/filter_test.go
index ff6f750b..d9928d8d 100644
--- a/filter_test.go
+++ b/filter_test.go
@@ -20,10 +20,16 @@ import (
"testing"
"github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
)
+func init() {
+ BeeLogger = logs.NewLogger(10000)
+ BeeLogger.SetLogger("console", "")
+}
+
var FilterUser = func(ctx *context.Context) {
- ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"]))
+ ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first")))
}
func TestFilter(t *testing.T) {
diff --git a/flash.go b/flash.go
index 5ccf339a..a6485a17 100644
--- a/flash.go
+++ b/flash.go
@@ -83,27 +83,27 @@ func (fd *FlashData) Store(c *Controller) {
c.Data["flash"] = fd.Data
var flashValue string
for key, value := range fd.Data {
- flashValue += "\x00" + key + "\x23" + FlashSeperator + "\x23" + value + "\x00"
+ flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00"
}
- c.Ctx.SetCookie(FlashName, url.QueryEscape(flashValue), 0, "/")
+ c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/")
}
// ReadFromRequest parsed flash data from encoded values in cookie.
func ReadFromRequest(c *Controller) *FlashData {
flash := NewFlash()
- if cookie, err := c.Ctx.Request.Cookie(FlashName); err == nil {
+ if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil {
v, _ := url.QueryUnescape(cookie.Value)
vals := strings.Split(v, "\x00")
for _, v := range vals {
if len(v) > 0 {
- kv := strings.Split(v, "\x23"+FlashSeperator+"\x23")
+ kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23")
if len(kv) == 2 {
flash.Data[kv[0]] = kv[1]
}
}
}
//read one time then delete it
- c.Ctx.SetCookie(FlashName, "", -1, "/")
+ c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/")
}
c.Data["flash"] = flash.Data
return flash
diff --git a/flash_test.go b/flash_test.go
index b655f552..640d54de 100644
--- a/flash_test.go
+++ b/flash_test.go
@@ -30,7 +30,7 @@ func (t *TestFlashController) TestWriteFlash() {
flash.Notice("TestFlashString")
flash.Store(&t.Controller)
// we choose to serve json because we don't want to load a template html file
- t.ServeJson(true)
+ t.ServeJSON(true)
}
func TestFlashHeader(t *testing.T) {
diff --git a/grace/conn.go b/grace/conn.go
index 2cf3a93d..6807e1ac 100644
--- a/grace/conn.go
+++ b/grace/conn.go
@@ -1,13 +1,28 @@
package grace
-import "net"
+import (
+ "errors"
+ "net"
+)
type graceConn struct {
net.Conn
- server *graceServer
+ server *Server
}
-func (c graceConn) Close() error {
+func (c graceConn) Close() (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ switch x := r.(type) {
+ case string:
+ err = errors.New(x)
+ case error:
+ err = x
+ default:
+ err = errors.New("Unknown panic")
+ }
+ }
+ }()
c.server.wg.Done()
return c.Conn.Close()
}
diff --git a/grace/grace.go b/grace/grace.go
index e5577267..af530d50 100644
--- a/grace/grace.go
+++ b/grace/grace.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package grace use to hot reload
// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/
//
// Usage:
@@ -32,7 +33,7 @@
// mux := http.NewServeMux()
// mux.HandleFunc("/hello", handler)
//
-// err := grace.ListenAndServe("localhost:8080", mux1)
+// err := grace.ListenAndServe("localhost:8080", mux)
// if err != nil {
// log.Println(err)
// }
@@ -52,46 +53,53 @@ import (
)
const (
- PRE_SIGNAL = iota
- POST_SIGNAL
-
- STATE_INIT
- STATE_RUNNING
- STATE_SHUTTING_DOWN
- STATE_TERMINATE
+ // PreSignal is the position to add filter before signal
+ PreSignal = iota
+ // PostSignal is the position to add filter after signal
+ PostSignal
+ // StateInit represent the application inited
+ StateInit
+ // StateRunning represent the application is running
+ StateRunning
+ // StateShuttingDown represent the application is shutting down
+ StateShuttingDown
+ // StateTerminate represent the application is killed
+ StateTerminate
)
var (
regLock *sync.Mutex
- runningServers map[string]*graceServer
+ runningServers map[string]*Server
runningServersOrder []string
socketPtrOffsetMap map[string]uint
runningServersForked bool
- DefaultReadTimeOut time.Duration
- DefaultWriteTimeOut time.Duration
+ // DefaultReadTimeOut is the HTTP read timeout
+ DefaultReadTimeOut time.Duration
+ // DefaultWriteTimeOut is the HTTP Write timeout
+ DefaultWriteTimeOut time.Duration
+ // DefaultMaxHeaderBytes is the Max HTTP Herder size, default is 0, no limit
DefaultMaxHeaderBytes int
- DefaultTimeout time.Duration
+ // DefaultTimeout is the shutdown server's timeout. default is 60s
+ DefaultTimeout = 60 * time.Second
isChild bool
socketOrder string
+ once sync.Once
)
-func init() {
+func onceInit() {
regLock = &sync.Mutex{}
flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
- runningServers = make(map[string]*graceServer)
+ runningServers = make(map[string]*Server)
runningServersOrder = []string{}
socketPtrOffsetMap = make(map[string]uint)
-
- DefaultMaxHeaderBytes = 0
-
- DefaultTimeout = 60 * time.Second
}
// NewServer returns a new graceServer.
-func NewServer(addr string, handler http.Handler) (srv *graceServer) {
+func NewServer(addr string, handler http.Handler) (srv *Server) {
+ once.Do(onceInit)
regLock.Lock()
defer regLock.Unlock()
if !flag.Parsed() {
@@ -105,23 +113,23 @@ func NewServer(addr string, handler http.Handler) (srv *graceServer) {
socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
}
- srv = &graceServer{
+ srv = &Server{
wg: sync.WaitGroup{},
sigChan: make(chan os.Signal),
isChild: isChild,
SignalHooks: map[int]map[os.Signal][]func(){
- PRE_SIGNAL: map[os.Signal][]func(){
+ PreSignal: map[os.Signal][]func(){
syscall.SIGHUP: []func(){},
syscall.SIGINT: []func(){},
syscall.SIGTERM: []func(){},
},
- POST_SIGNAL: map[os.Signal][]func(){
+ PostSignal: map[os.Signal][]func(){
syscall.SIGHUP: []func(){},
syscall.SIGINT: []func(){},
syscall.SIGTERM: []func(){},
},
},
- state: STATE_INIT,
+ state: StateInit,
Network: "tcp",
}
srv.Server = &http.Server{}
@@ -137,13 +145,13 @@ func NewServer(addr string, handler http.Handler) (srv *graceServer) {
return
}
-// refer http.ListenAndServe
+// ListenAndServe refer http.ListenAndServe
func ListenAndServe(addr string, handler http.Handler) error {
server := NewServer(addr, handler)
return server.ListenAndServe()
}
-// refer http.ListenAndServeTLS
+// ListenAndServeTLS refer http.ListenAndServeTLS
func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
server := NewServer(addr, handler)
return server.ListenAndServeTLS(certFile, keyFile)
diff --git a/grace/listener.go b/grace/listener.go
index 8c5d4f9b..5439d0b2 100644
--- a/grace/listener.go
+++ b/grace/listener.go
@@ -11,10 +11,10 @@ type graceListener struct {
net.Listener
stop chan error
stopped bool
- server *graceServer
+ server *Server
}
-func newGraceListener(l net.Listener, srv *graceServer) (el *graceListener) {
+func newGraceListener(l net.Listener, srv *Server) (el *graceListener) {
el = &graceListener{
Listener: l,
stop: make(chan error),
@@ -46,17 +46,17 @@ func (gl *graceListener) Accept() (c net.Conn, err error) {
return
}
-func (el *graceListener) Close() error {
- if el.stopped {
+func (gl *graceListener) Close() error {
+ if gl.stopped {
return syscall.EINVAL
}
- el.stop <- nil
- return <-el.stop
+ gl.stop <- nil
+ return <-gl.stop
}
-func (el *graceListener) File() *os.File {
+func (gl *graceListener) File() *os.File {
// returns a dup(2) - FD_CLOEXEC flag *not* set
- tl := el.Listener.(*net.TCPListener)
+ tl := gl.Listener.(*net.TCPListener)
fl, _ := tl.File()
return fl
}
diff --git a/grace/server.go b/grace/server.go
index aea8d7d3..f4512ded 100644
--- a/grace/server.go
+++ b/grace/server.go
@@ -15,7 +15,8 @@ import (
"time"
)
-type graceServer struct {
+// Server embedded http.Server
+type Server struct {
*http.Server
GraceListener net.Listener
SignalHooks map[int]map[os.Signal][]func()
@@ -30,19 +31,19 @@ type graceServer struct {
// Serve accepts incoming connections on the Listener l,
// creating a new service goroutine for each.
// The service goroutines read requests and then call srv.Handler to reply to them.
-func (srv *graceServer) Serve() (err error) {
- srv.state = STATE_RUNNING
+func (srv *Server) Serve() (err error) {
+ srv.state = StateRunning
err = srv.Server.Serve(srv.GraceListener)
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
srv.wg.Wait()
- srv.state = STATE_TERMINATE
+ srv.state = StateTerminate
return
}
// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
// to handle requests on incoming connections. If srv.Addr is blank, ":http" is
// used.
-func (srv *graceServer) ListenAndServe() (err error) {
+func (srv *Server) ListenAndServe() (err error) {
addr := srv.Addr
if addr == "" {
addr = ":http"
@@ -83,7 +84,7 @@ func (srv *graceServer) ListenAndServe() (err error) {
// CA's certificate.
//
// If srv.Addr is blank, ":https" is used.
-func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error) {
+func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
addr := srv.Addr
if addr == "" {
addr = ":https"
@@ -131,9 +132,9 @@ func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error)
// getListener either opens a new socket to listen on, or takes the acceptor socket
// it got passed when restarted.
-func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) {
+func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
if srv.isChild {
- var ptrOffset uint = 0
+ var ptrOffset uint
if len(socketPtrOffsetMap) > 0 {
ptrOffset = socketPtrOffsetMap[laddr]
log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
@@ -157,7 +158,7 @@ func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) {
// handleSignals listens for os Signals and calls any hooked in function that the
// user had registered with the signal.
-func (srv *graceServer) handleSignals() {
+func (srv *Server) handleSignals() {
var sig os.Signal
signal.Notify(
@@ -170,7 +171,7 @@ func (srv *graceServer) handleSignals() {
pid := syscall.Getpid()
for {
sig = <-srv.sigChan
- srv.signalHooks(PRE_SIGNAL, sig)
+ srv.signalHooks(PreSignal, sig)
switch sig {
case syscall.SIGHUP:
log.Println(pid, "Received SIGHUP. forking.")
@@ -187,11 +188,11 @@ func (srv *graceServer) handleSignals() {
default:
log.Printf("Received %v: nothing i care about...\n", sig)
}
- srv.signalHooks(POST_SIGNAL, sig)
+ srv.signalHooks(PostSignal, sig)
}
}
-func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) {
+func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
return
}
@@ -204,12 +205,12 @@ func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) {
// shutdown closes the listener so that no new connections are accepted. it also
// starts a goroutine that will serverTimeout (stop all running requests) the server
// after DefaultTimeout.
-func (srv *graceServer) shutdown() {
- if srv.state != STATE_RUNNING {
+func (srv *Server) shutdown() {
+ if srv.state != StateRunning {
return
}
- srv.state = STATE_SHUTTING_DOWN
+ srv.state = StateShuttingDown
if DefaultTimeout >= 0 {
go srv.serverTimeout(DefaultTimeout)
}
@@ -224,26 +225,26 @@ func (srv *graceServer) shutdown() {
// serverTimeout forces the server to shutdown in a given timeout - whether it
// finished outstanding requests or not. if Read/WriteTimeout are not set or the
// max header size is very big a connection could hang
-func (srv *graceServer) serverTimeout(d time.Duration) {
+func (srv *Server) serverTimeout(d time.Duration) {
defer func() {
if r := recover(); r != nil {
log.Println("WaitGroup at 0", r)
}
}()
- if srv.state != STATE_SHUTTING_DOWN {
+ if srv.state != StateShuttingDown {
return
}
time.Sleep(d)
log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
for {
- if srv.state == STATE_TERMINATE {
+ if srv.state == StateTerminate {
break
}
srv.wg.Done()
}
}
-func (srv *graceServer) fork() (err error) {
+func (srv *Server) fork() (err error) {
regLock.Lock()
defer regLock.Unlock()
if runningServersForked {
diff --git a/hooks.go b/hooks.go
new file mode 100644
index 00000000..78abf8ef
--- /dev/null
+++ b/hooks.go
@@ -0,0 +1,95 @@
+package beego
+
+import (
+ "encoding/json"
+ "mime"
+ "net/http"
+ "path/filepath"
+
+ "github.com/astaxie/beego/session"
+)
+
+//
+func registerMime() error {
+ for k, v := range mimemaps {
+ mime.AddExtensionType(k, v)
+ }
+ return nil
+}
+
+// register default error http handlers, 404,401,403,500 and 503.
+func registerDefaultErrorHandler() error {
+ m := map[string]func(http.ResponseWriter, *http.Request){
+ "401": unauthorized,
+ "402": paymentRequired,
+ "403": forbidden,
+ "404": notFound,
+ "405": methodNotAllowed,
+ "500": internalServerError,
+ "501": notImplemented,
+ "502": badGateway,
+ "503": serviceUnavailable,
+ "504": gatewayTimeout,
+ }
+ for e, h := range m {
+ if _, ok := ErrorMaps[e]; !ok {
+ ErrorHandler(e, h)
+ }
+ }
+ return nil
+}
+
+func registerSession() error {
+ if BConfig.WebConfig.Session.SessionOn {
+ var err error
+ sessionConfig := AppConfig.String("sessionConfig")
+ if sessionConfig == "" {
+ conf := map[string]interface{}{
+ "cookieName": BConfig.WebConfig.Session.SessionName,
+ "gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime,
+ "providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig),
+ "secure": BConfig.Listen.EnableHTTPS,
+ "enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie,
+ "domain": BConfig.WebConfig.Session.SessionDomain,
+ "cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime,
+ }
+ confBytes, err := json.Marshal(conf)
+ if err != nil {
+ return err
+ }
+ sessionConfig = string(confBytes)
+ }
+ if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, sessionConfig); err != nil {
+ return err
+ }
+ go GlobalSessions.GC()
+ }
+ return nil
+}
+
+func registerTemplate() error {
+ if BConfig.WebConfig.AutoRender {
+ if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil {
+ if BConfig.RunMode == DEV {
+ Warn(err)
+ }
+ return err
+ }
+ }
+ return nil
+}
+
+func registerDocs() error {
+ if BConfig.WebConfig.EnableDocs {
+ Get("/docs", serverDocs)
+ Get("/docs/*", serverDocs)
+ }
+ return nil
+}
+
+func registerAdmin() error {
+ if BConfig.Listen.EnableAdmin {
+ go beeAdminApp.Run()
+ }
+ return nil
+}
diff --git a/httplib/httplib.go b/httplib/httplib.go
index 68c22d70..fb64a30a 100644
--- a/httplib/httplib.go
+++ b/httplib/httplib.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package httplib is used as http.Client
// Usage:
//
// import "github.com/astaxie/beego/httplib"
@@ -51,7 +52,7 @@ import (
"time"
)
-var defaultSetting = BeegoHttpSettings{
+var defaultSetting = BeegoHTTPSettings{
UserAgent: "beegoServer",
ConnectTimeout: 60 * time.Second,
ReadWriteTimeout: 60 * time.Second,
@@ -69,25 +70,19 @@ func createDefaultCookie() {
defaultCookieJar, _ = cookiejar.New(nil)
}
-// Overwrite default settings
-func SetDefaultSetting(setting BeegoHttpSettings) {
+// SetDefaultSetting Overwrite default settings
+func SetDefaultSetting(setting BeegoHTTPSettings) {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultSetting = setting
- if defaultSetting.ConnectTimeout == 0 {
- defaultSetting.ConnectTimeout = 60 * time.Second
- }
- if defaultSetting.ReadWriteTimeout == 0 {
- defaultSetting.ReadWriteTimeout = 60 * time.Second
- }
}
-// return *BeegoHttpRequest with specific method
-func NewBeegoRequest(rawurl, method string) *BeegoHttpRequest {
+// NewBeegoRequest return *BeegoHttpRequest with specific method
+func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest {
var resp http.Response
u, err := url.Parse(rawurl)
if err != nil {
- log.Fatal(err)
+ log.Println("Httplib:", err)
}
req := http.Request{
URL: u,
@@ -97,10 +92,10 @@ func NewBeegoRequest(rawurl, method string) *BeegoHttpRequest {
ProtoMajor: 1,
ProtoMinor: 1,
}
- return &BeegoHttpRequest{
+ return &BeegoHTTPRequest{
url: rawurl,
req: &req,
- params: map[string]string{},
+ params: map[string][]string{},
files: map[string]string{},
setting: defaultSetting,
resp: &resp,
@@ -108,37 +103,37 @@ func NewBeegoRequest(rawurl, method string) *BeegoHttpRequest {
}
// Get returns *BeegoHttpRequest with GET method.
-func Get(url string) *BeegoHttpRequest {
+func Get(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "GET")
}
// Post returns *BeegoHttpRequest with POST method.
-func Post(url string) *BeegoHttpRequest {
+func Post(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "POST")
}
// Put returns *BeegoHttpRequest with PUT method.
-func Put(url string) *BeegoHttpRequest {
+func Put(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "PUT")
}
// Delete returns *BeegoHttpRequest DELETE method.
-func Delete(url string) *BeegoHttpRequest {
+func Delete(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "DELETE")
}
// Head returns *BeegoHttpRequest with HEAD method.
-func Head(url string) *BeegoHttpRequest {
+func Head(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "HEAD")
}
-// BeegoHttpSettings
-type BeegoHttpSettings struct {
+// BeegoHTTPSettings is the http.Client setting
+type BeegoHTTPSettings struct {
ShowDebug bool
UserAgent string
ConnectTimeout time.Duration
ReadWriteTimeout time.Duration
- TlsClientConfig *tls.Config
+ TLSClientConfig *tls.Config
Proxy func(*http.Request) (*url.URL, error)
Transport http.RoundTripper
EnableCookie bool
@@ -146,92 +141,92 @@ type BeegoHttpSettings struct {
DumpBody bool
}
-// BeegoHttpRequest provides more useful methods for requesting one url than http.Request.
-type BeegoHttpRequest struct {
+// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request.
+type BeegoHTTPRequest struct {
url string
req *http.Request
- params map[string]string
+ params map[string][]string
files map[string]string
- setting BeegoHttpSettings
+ setting BeegoHTTPSettings
resp *http.Response
body []byte
dump []byte
}
-// get request
-func (b *BeegoHttpRequest) GetRequest() *http.Request {
+// GetRequest return the request object
+func (b *BeegoHTTPRequest) GetRequest() *http.Request {
return b.req
}
-// Change request settings
-func (b *BeegoHttpRequest) Setting(setting BeegoHttpSettings) *BeegoHttpRequest {
+// Setting Change request settings
+func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest {
b.setting = setting
return b
}
// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password.
-func (b *BeegoHttpRequest) SetBasicAuth(username, password string) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest {
b.req.SetBasicAuth(username, password)
return b
}
// SetEnableCookie sets enable/disable cookiejar
-func (b *BeegoHttpRequest) SetEnableCookie(enable bool) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest {
b.setting.EnableCookie = enable
return b
}
// SetUserAgent sets User-Agent header field
-func (b *BeegoHttpRequest) SetUserAgent(useragent string) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest {
b.setting.UserAgent = useragent
return b
}
// Debug sets show debug or not when executing request.
-func (b *BeegoHttpRequest) Debug(isdebug bool) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
b.setting.ShowDebug = isdebug
return b
}
-// Dump Body.
-func (b *BeegoHttpRequest) DumpBody(isdump bool) *BeegoHttpRequest {
+// DumpBody setting whether need to Dump the Body.
+func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
b.setting.DumpBody = isdump
return b
}
-// return the DumpRequest
-func (b *BeegoHttpRequest) DumpRequest() []byte {
+// DumpRequest return the DumpRequest
+func (b *BeegoHTTPRequest) DumpRequest() []byte {
return b.dump
}
// SetTimeout sets connect time out and read-write time out for BeegoRequest.
-func (b *BeegoHttpRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest {
b.setting.ConnectTimeout = connectTimeout
b.setting.ReadWriteTimeout = readWriteTimeout
return b
}
// SetTLSClientConfig sets tls connection configurations if visiting https url.
-func (b *BeegoHttpRequest) SetTLSClientConfig(config *tls.Config) *BeegoHttpRequest {
- b.setting.TlsClientConfig = config
+func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest {
+ b.setting.TLSClientConfig = config
return b
}
// Header add header item string in request.
-func (b *BeegoHttpRequest) Header(key, value string) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest {
b.req.Header.Set(key, value)
return b
}
-// Set HOST
-func (b *BeegoHttpRequest) SetHost(host string) *BeegoHttpRequest {
+// SetHost set the request host
+func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest {
b.req.Host = host
return b
}
-// Set the protocol version for incoming requests.
+// SetProtocolVersion Set the protocol version for incoming requests.
// Client requests always use HTTP/1.1.
-func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
if len(vers) == 0 {
vers = "HTTP/1.1"
}
@@ -247,44 +242,49 @@ func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest {
}
// SetCookie add cookie into request.
-func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest {
b.req.Header.Add("Cookie", cookie.String())
return b
}
-// Set transport to
-func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest {
+// SetTransport set the setting transport
+func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest {
b.setting.Transport = transport
return b
}
-// Set http proxy
+// SetProxy set the http proxy
// example:
//
// func(req *http.Request) (*url.URL, error) {
// u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
// return u, nil
// }
-func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest {
b.setting.Proxy = proxy
return b
}
// Param adds query param in to request.
// params build query string as ?key1=value1&key2=value2...
-func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest {
- b.params[key] = value
+func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest {
+ if param, ok := b.params[key]; ok {
+ b.params[key] = append(param, value)
+ } else {
+ b.params[key] = []string{value}
+ }
return b
}
-func (b *BeegoHttpRequest) PostFile(formname, filename string) *BeegoHttpRequest {
+// PostFile add a post file to the request
+func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest {
b.files[formname] = filename
return b
}
// Body adds request raw body.
// it supports string and []byte.
-func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
switch t := data.(type) {
case string:
bf := bytes.NewBufferString(t)
@@ -298,8 +298,8 @@ func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest {
return b
}
-// JsonBody adds request raw body encoding by JSON.
-func (b *BeegoHttpRequest) JsonBody(obj interface{}) (*BeegoHttpRequest, error) {
+// JSONBody adds request raw body encoding by JSON.
+func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil {
buf := bytes.NewBuffer(nil)
enc := json.NewEncoder(buf)
@@ -313,7 +313,7 @@ func (b *BeegoHttpRequest) JsonBody(obj interface{}) (*BeegoHttpRequest, error)
return b, nil
}
-func (b *BeegoHttpRequest) buildUrl(paramBody string) {
+func (b *BeegoHTTPRequest) buildURL(paramBody string) {
// build GET url with query string
if b.req.Method == "GET" && len(paramBody) > 0 {
if strings.Index(b.url, "?") != -1 {
@@ -334,21 +334,23 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) {
for formname, filename := range b.files {
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
if err != nil {
- log.Fatal(err)
+ log.Println("Httplib:", err)
}
fh, err := os.Open(filename)
if err != nil {
- log.Fatal(err)
+ log.Println("Httplib:", err)
}
//iocopy
_, err = io.Copy(fileWriter, fh)
fh.Close()
if err != nil {
- log.Fatal(err)
+ log.Println("Httplib:", err)
}
}
for k, v := range b.params {
- bodyWriter.WriteField(k, v)
+ for _, vv := range v {
+ bodyWriter.WriteField(k, vv)
+ }
}
bodyWriter.Close()
pw.Close()
@@ -366,11 +368,11 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) {
}
}
-func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
+func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
if b.resp.StatusCode != 0 {
return b.resp, nil
}
- resp, err := b.SendOut()
+ resp, err := b.DoRequest()
if err != nil {
return nil, err
}
@@ -378,21 +380,24 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
return resp, nil
}
-func (b *BeegoHttpRequest) SendOut() (*http.Response, error) {
+// DoRequest will do the client.Do
+func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
var paramBody string
if len(b.params) > 0 {
var buf bytes.Buffer
for k, v := range b.params {
- buf.WriteString(url.QueryEscape(k))
- buf.WriteByte('=')
- buf.WriteString(url.QueryEscape(v))
- buf.WriteByte('&')
+ for _, vv := range v {
+ buf.WriteString(url.QueryEscape(k))
+ buf.WriteByte('=')
+ buf.WriteString(url.QueryEscape(vv))
+ buf.WriteByte('&')
+ }
}
paramBody = buf.String()
paramBody = paramBody[0 : len(paramBody)-1]
}
- b.buildUrl(paramBody)
+ b.buildURL(paramBody)
url, err := url.Parse(b.url)
if err != nil {
return nil, err
@@ -405,7 +410,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) {
if trans == nil {
// create default transport
trans = &http.Transport{
- TLSClientConfig: b.setting.TlsClientConfig,
+ TLSClientConfig: b.setting.TLSClientConfig,
Proxy: b.setting.Proxy,
Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
}
@@ -413,7 +418,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) {
// if b.transport is *http.Transport then set the settings.
if t, ok := trans.(*http.Transport); ok {
if t.TLSClientConfig == nil {
- t.TLSClientConfig = b.setting.TlsClientConfig
+ t.TLSClientConfig = b.setting.TLSClientConfig
}
if t.Proxy == nil {
t.Proxy = b.setting.Proxy
@@ -424,7 +429,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) {
}
}
- var jar http.CookieJar = nil
+ var jar http.CookieJar
if b.setting.EnableCookie {
if defaultCookieJar == nil {
createDefaultCookie()
@@ -453,7 +458,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) {
// String returns the body string in response.
// it calls Response inner.
-func (b *BeegoHttpRequest) String() (string, error) {
+func (b *BeegoHTTPRequest) String() (string, error) {
data, err := b.Bytes()
if err != nil {
return "", err
@@ -464,7 +469,7 @@ func (b *BeegoHttpRequest) String() (string, error) {
// Bytes returns the body []byte in response.
// it calls Response inner.
-func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
+func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
if b.body != nil {
return b.body, nil
}
@@ -490,7 +495,7 @@ func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
// ToFile saves the body data in response to one file.
// it calls Response inner.
-func (b *BeegoHttpRequest) ToFile(filename string) error {
+func (b *BeegoHTTPRequest) ToFile(filename string) error {
f, err := os.Create(filename)
if err != nil {
return err
@@ -509,9 +514,9 @@ func (b *BeegoHttpRequest) ToFile(filename string) error {
return err
}
-// ToJson returns the map that marshals from the body bytes as json in response .
+// ToJSON returns the map that marshals from the body bytes as json in response .
// it calls Response inner.
-func (b *BeegoHttpRequest) ToJson(v interface{}) error {
+func (b *BeegoHTTPRequest) ToJSON(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
@@ -519,9 +524,9 @@ func (b *BeegoHttpRequest) ToJson(v interface{}) error {
return json.Unmarshal(data, v)
}
-// ToXml returns the map that marshals from the body bytes as xml in response .
+// ToXML returns the map that marshals from the body bytes as xml in response .
// it calls Response inner.
-func (b *BeegoHttpRequest) ToXml(v interface{}) error {
+func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
@@ -530,7 +535,7 @@ func (b *BeegoHttpRequest) ToXml(v interface{}) error {
}
// Response executes request client gets response mannually.
-func (b *BeegoHttpRequest) Response() (*http.Response, error) {
+func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
return b.getResponse()
}
diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go
index 0b551c53..05815054 100644
--- a/httplib/httplib_test.go
+++ b/httplib/httplib_test.go
@@ -19,6 +19,7 @@ import (
"os"
"strings"
"testing"
+ "time"
)
func TestResponse(t *testing.T) {
@@ -149,10 +150,11 @@ func TestWithUserAgent(t *testing.T) {
func TestWithSetting(t *testing.T) {
v := "beego"
- var setting BeegoHttpSettings
+ var setting BeegoHTTPSettings
setting.EnableCookie = true
setting.UserAgent = v
setting.Transport = nil
+ setting.ReadWriteTimeout = 5 * time.Second
SetDefaultSetting(setting)
str, err := Get("http://httpbin.org/get").String()
@@ -176,11 +178,11 @@ func TestToJson(t *testing.T) {
t.Log(resp)
// httpbin will return http remote addr
- type Ip struct {
+ type IP struct {
Origin string `json:"origin"`
}
- var ip Ip
- err = req.ToJson(&ip)
+ var ip IP
+ err = req.ToJSON(&ip)
if err != nil {
t.Fatal(err)
}
diff --git a/log.go b/log.go
index 7949ed96..46ec57dd 100644
--- a/log.go
+++ b/log.go
@@ -32,20 +32,20 @@ const (
LevelDebug
)
-// SetLogLevel sets the global log level used by the simple
-// logger.
+// BeeLogger references the used application logger.
+var BeeLogger = logs.NewLogger(100)
+
+// SetLevel sets the global log level used by the simple logger.
func SetLevel(l int) {
BeeLogger.SetLevel(l)
}
+// SetLogFuncCall set the CallDepth, default is 3
func SetLogFuncCall(b bool) {
BeeLogger.EnableFuncCallDepth(b)
BeeLogger.SetLogFuncCallDepth(3)
}
-// logger references the used application logger.
-var BeeLogger *logs.BeeLogger
-
// SetLogger sets a new logger.
func SetLogger(adaptername string, config string) error {
err := BeeLogger.SetLogger(adaptername, config)
@@ -55,10 +55,12 @@ func SetLogger(adaptername string, config string) error {
return nil
}
+// Emergency logs a message at emergency level.
func Emergency(v ...interface{}) {
BeeLogger.Emergency(generateFmtStr(len(v)), v...)
}
+// Alert logs a message at alert level.
func Alert(v ...interface{}) {
BeeLogger.Alert(generateFmtStr(len(v)), v...)
}
@@ -78,21 +80,22 @@ func Warning(v ...interface{}) {
BeeLogger.Warning(generateFmtStr(len(v)), v...)
}
-// compatibility alias for Warning()
+// Warn compatibility alias for Warning()
func Warn(v ...interface{}) {
BeeLogger.Warn(generateFmtStr(len(v)), v...)
}
+// Notice logs a message at notice level.
func Notice(v ...interface{}) {
BeeLogger.Notice(generateFmtStr(len(v)), v...)
}
-// Info logs a message at info level.
+// Informational logs a message at info level.
func Informational(v ...interface{}) {
BeeLogger.Informational(generateFmtStr(len(v)), v...)
}
-// compatibility alias for Warning()
+// Info compatibility alias for Warning()
func Info(v ...interface{}) {
BeeLogger.Info(generateFmtStr(len(v)), v...)
}
diff --git a/logs/conn.go b/logs/conn.go
index 2240eece..3655bf51 100644
--- a/logs/conn.go
+++ b/logs/conn.go
@@ -21,9 +21,9 @@ import (
"net"
)
-// ConnWriter implements LoggerInterface.
+// connWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection.
-type ConnWriter struct {
+type connWriter struct {
lg *log.Logger
innerWriter io.WriteCloser
ReconnectOnMsg bool `json:"reconnectOnMsg"`
@@ -33,22 +33,22 @@ type ConnWriter struct {
Level int `json:"level"`
}
-// create new ConnWrite returning as LoggerInterface.
-func NewConn() LoggerInterface {
- conn := new(ConnWriter)
+// NewConn create new ConnWrite returning as LoggerInterface.
+func NewConn() Logger {
+ conn := new(connWriter)
conn.Level = LevelTrace
return conn
}
-// init connection writer with json config.
+// Init init connection writer with json config.
// json config only need key "level".
-func (c *ConnWriter) Init(jsonconfig string) error {
+func (c *connWriter) Init(jsonconfig string) error {
return json.Unmarshal([]byte(jsonconfig), c)
}
-// write message in connection.
+// WriteMsg write message in connection.
// if connection is down, try to re-connect.
-func (c *ConnWriter) WriteMsg(msg string, level int) error {
+func (c *connWriter) WriteMsg(msg string, level int) error {
if level > c.Level {
return nil
}
@@ -66,19 +66,19 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
return nil
}
-// implementing method. empty.
-func (c *ConnWriter) Flush() {
+// Flush implementing method. empty.
+func (c *connWriter) Flush() {
}
-// destroy connection writer and close tcp listener.
-func (c *ConnWriter) Destroy() {
+// Destroy destroy connection writer and close tcp listener.
+func (c *connWriter) Destroy() {
if c.innerWriter != nil {
c.innerWriter.Close()
}
}
-func (c *ConnWriter) connect() error {
+func (c *connWriter) connect() error {
if c.innerWriter != nil {
c.innerWriter.Close()
c.innerWriter = nil
@@ -98,7 +98,7 @@ func (c *ConnWriter) connect() error {
return nil
}
-func (c *ConnWriter) neddedConnectOnMsg() bool {
+func (c *connWriter) neddedConnectOnMsg() bool {
if c.Reconnect {
c.Reconnect = false
return true
diff --git a/logs/console.go b/logs/console.go
index ce7ecd54..23e8ebca 100644
--- a/logs/console.go
+++ b/logs/console.go
@@ -21,9 +21,11 @@ import (
"runtime"
)
-type Brush func(string) string
+// brush is a color join function
+type brush func(string) string
-func NewBrush(color string) Brush {
+// newBrush return a fix color Brush
+func newBrush(color string) brush {
pre := "\033["
reset := "\033[0m"
return func(text string) string {
@@ -31,43 +33,43 @@ func NewBrush(color string) Brush {
}
}
-var colors = []Brush{
- NewBrush("1;37"), // Emergency white
- NewBrush("1;36"), // Alert cyan
- NewBrush("1;35"), // Critical magenta
- NewBrush("1;31"), // Error red
- NewBrush("1;33"), // Warning yellow
- NewBrush("1;32"), // Notice green
- NewBrush("1;34"), // Informational blue
- NewBrush("1;34"), // Debug blue
+var colors = []brush{
+ newBrush("1;37"), // Emergency white
+ newBrush("1;36"), // Alert cyan
+ newBrush("1;35"), // Critical magenta
+ newBrush("1;31"), // Error red
+ newBrush("1;33"), // Warning yellow
+ newBrush("1;32"), // Notice green
+ newBrush("1;34"), // Informational blue
+ newBrush("1;34"), // Debug blue
}
-// ConsoleWriter implements LoggerInterface and writes messages to terminal.
-type ConsoleWriter struct {
+// consoleWriter implements LoggerInterface and writes messages to terminal.
+type consoleWriter struct {
lg *log.Logger
Level int `json:"level"`
}
-// create ConsoleWriter returning as LoggerInterface.
-func NewConsole() LoggerInterface {
- cw := &ConsoleWriter{
+// NewConsole create ConsoleWriter returning as LoggerInterface.
+func NewConsole() Logger {
+ cw := &consoleWriter{
lg: log.New(os.Stdout, "", log.Ldate|log.Ltime),
Level: LevelDebug,
}
return cw
}
-// init console logger.
+// Init init console logger.
// jsonconfig like '{"level":LevelTrace}'.
-func (c *ConsoleWriter) Init(jsonconfig string) error {
+func (c *consoleWriter) Init(jsonconfig string) error {
if len(jsonconfig) == 0 {
return nil
}
return json.Unmarshal([]byte(jsonconfig), c)
}
-// write message in console.
-func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
+// WriteMsg write message in console.
+func (c *consoleWriter) WriteMsg(msg string, level int) error {
if level > c.Level {
return nil
}
@@ -80,13 +82,13 @@ func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
return nil
}
-// implementing method. empty.
-func (c *ConsoleWriter) Destroy() {
+// Destroy implementing method. empty.
+func (c *consoleWriter) Destroy() {
}
-// implementing method. empty.
-func (c *ConsoleWriter) Flush() {
+// Flush implementing method. empty.
+func (c *consoleWriter) Flush() {
}
diff --git a/logs/console_test.go b/logs/console_test.go
index 2fad7241..c4bb1da2 100644
--- a/logs/console_test.go
+++ b/logs/console_test.go
@@ -43,11 +43,3 @@ func TestConsole(t *testing.T) {
testConsoleCalls(log2)
}
-func BenchmarkConsole(b *testing.B) {
- log := NewLogger(10000)
- log.EnableFuncCallDepth(true)
- log.SetLogger("console", "")
- for i := 0; i < b.N; i++ {
- log.Debug("debug")
- }
-}
diff --git a/logs/es/es.go b/logs/es/es.go
index 3a73d4dd..f8dc5f65 100644
--- a/logs/es/es.go
+++ b/logs/es/es.go
@@ -12,7 +12,8 @@ import (
"github.com/belogik/goes"
)
-func NewES() logs.LoggerInterface {
+// NewES return a LoggerInterface
+func NewES() logs.Logger {
cw := &esLogger{
Level: logs.LevelDebug,
}
@@ -46,6 +47,7 @@ func (el *esLogger) Init(jsonconfig string) error {
return nil
}
+// WriteMsg will write the msg and level into es
func (el *esLogger) WriteMsg(msg string, level int) error {
if level > el.Level {
return nil
@@ -63,10 +65,12 @@ func (el *esLogger) WriteMsg(msg string, level int) error {
return err
}
+// Destroy is a empty method
func (el *esLogger) Destroy() {
}
+// Flush is a empty method
func (el *esLogger) Flush() {
}
diff --git a/logs/file.go b/logs/file.go
index 2d3449ce..0eae734a 100644
--- a/logs/file.go
+++ b/logs/file.go
@@ -20,7 +20,6 @@ import (
"errors"
"fmt"
"io"
- "log"
"os"
"path/filepath"
"strings"
@@ -28,84 +27,62 @@ import (
"time"
)
-// FileLogWriter implements LoggerInterface.
+// fileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency.
-type FileLogWriter struct {
- *log.Logger
- mw *MuxWriter
+type fileLogWriter struct {
+ sync.Mutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize
// The opened file
- Filename string `json:"filename"`
+ Filename string `json:"filename"`
+ fileWriter *os.File
- Maxlines int `json:"maxlines"`
- maxlines_curlines int
+ // Rotate at line
+ MaxLines int `json:"maxlines"`
+ maxLinesCurLines int
// Rotate at size
- Maxsize int `json:"maxsize"`
- maxsize_cursize int
+ MaxSize int `json:"maxsize"`
+ maxSizeCurSize int
// Rotate daily
- Daily bool `json:"daily"`
- Maxdays int64 `json:"maxdays"`
- daily_opendate int
+ Daily bool `json:"daily"`
+ MaxDays int64 `json:"maxdays"`
+ dailyOpenDate int
Rotate bool `json:"rotate"`
- startLock sync.Mutex // Only one log can write to the file
-
Level int `json:"level"`
+
+ Perm os.FileMode `json:"perm"`
}
-// an *os.File writer with locker.
-type MuxWriter struct {
- sync.Mutex
- fd *os.File
-}
-
-// write to os.File.
-func (l *MuxWriter) Write(b []byte) (int, error) {
- l.Lock()
- defer l.Unlock()
- return l.fd.Write(b)
-}
-
-// set os.File in writer.
-func (l *MuxWriter) SetFd(fd *os.File) {
- if l.fd != nil {
- l.fd.Close()
- }
- l.fd = fd
-}
-
-// create a FileLogWriter returning as LoggerInterface.
-func NewFileWriter() LoggerInterface {
- w := &FileLogWriter{
+// NewFileWriter create a FileLogWriter returning as LoggerInterface.
+func newFileWriter() Logger {
+ w := &fileLogWriter{
Filename: "",
- Maxlines: 1000000,
- Maxsize: 1 << 28, //256 MB
+ MaxLines: 1000000,
+ MaxSize: 1 << 28, //256 MB
Daily: true,
- Maxdays: 7,
+ MaxDays: 7,
Rotate: true,
Level: LevelTrace,
+ Perm: 0660,
}
- // use MuxWriter instead direct use os.File for lock write when rotate
- w.mw = new(MuxWriter)
- // set MuxWriter as Logger's io.Writer
- w.Logger = log.New(w.mw, "", log.Ldate|log.Ltime)
return w
}
// Init file logger with json config.
-// jsonconfig like:
+// jsonConfig like:
// {
// "filename":"logs/beego.log",
-// "maxlines":10000,
+// "maxLines":10000,
// "maxsize":1<<30,
// "daily":true,
-// "maxdays":15,
-// "rotate":true
+// "maxDays":15,
+// "rotate":true,
+// "perm":0600
// }
-func (w *FileLogWriter) Init(jsonconfig string) error {
- err := json.Unmarshal([]byte(jsonconfig), w)
+func (w *fileLogWriter) Init(jsonConfig string) error {
+ err := json.Unmarshal([]byte(jsonConfig), w)
if err != nil {
return err
}
@@ -117,67 +94,120 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
}
// start file logger. create log file and set to locker-inside file writer.
-func (w *FileLogWriter) startLogger() error {
- fd, err := w.createLogFile()
+func (w *fileLogWriter) startLogger() error {
+ file, err := w.createLogFile()
if err != nil {
return err
}
- w.mw.SetFd(fd)
+ if w.fileWriter != nil {
+ w.fileWriter.Close()
+ }
+ w.fileWriter = file
return w.initFd()
}
-func (w *FileLogWriter) docheck(size int) {
- w.startLock.Lock()
- defer w.startLock.Unlock()
- if w.Rotate && ((w.Maxlines > 0 && w.maxlines_curlines >= w.Maxlines) ||
- (w.Maxsize > 0 && w.maxsize_cursize >= w.Maxsize) ||
- (w.Daily && time.Now().Day() != w.daily_opendate)) {
- if err := w.DoRotate(); err != nil {
- fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
- return
- }
- }
- w.maxlines_curlines++
- w.maxsize_cursize += size
+func (w *fileLogWriter) needRotate(size int, day int) bool {
+ return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
+ (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
+ (w.Daily && day != w.dailyOpenDate)
+
}
-// write logger message into file.
-func (w *FileLogWriter) WriteMsg(msg string, level int) error {
+// WriteMsg write logger message into file.
+func (w *fileLogWriter) WriteMsg(msg string, level int) error {
if level > w.Level {
return nil
}
- n := 24 + len(msg) // 24 stand for the length "2013/06/23 21:00:22 [T] "
- w.docheck(n)
- w.Logger.Println(msg)
- return nil
+ //2016/01/12 21:34:33
+ now := time.Now()
+ y, mo, d := now.Date()
+ h, mi, s := now.Clock()
+ //len(2006/01/02 15:03:04)==19
+ var buf [20]byte
+ t := 3
+ for y >= 10 {
+ p := y / 10
+ buf[t] = byte('0' + y - p*10)
+ y = p
+ t--
+ }
+ buf[0] = byte('0' + y)
+ buf[4] = '/'
+ if mo > 9 {
+ buf[5] = '1'
+ buf[6] = byte('0' + mo - 9)
+ } else {
+ buf[5] = '0'
+ buf[6] = byte('0' + mo)
+ }
+ buf[7] = '/'
+ t = d / 10
+ buf[8] = byte('0' + t)
+ buf[9] = byte('0' + d - t*10)
+ buf[10] = ' '
+ t = h / 10
+ buf[11] = byte('0' + t)
+ buf[12] = byte('0' + h - t*10)
+ buf[13] = ':'
+ t = mi / 10
+ buf[14] = byte('0' + t)
+ buf[15] = byte('0' + mi - t*10)
+ buf[16] = ':'
+ t = s / 10
+ buf[17] = byte('0' + t)
+ buf[18] = byte('0' + s - t*10)
+ buf[19] = ' '
+ msg = string(buf[0:]) + msg + "\n"
+
+ if w.Rotate {
+ if w.needRotate(len(msg), d) {
+ w.Lock()
+ if w.needRotate(len(msg), d) {
+ if err := w.doRotate(); err != nil {
+ fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
+ }
+
+ }
+ w.Unlock()
+ }
+ }
+
+ w.Lock()
+ _, err := w.fileWriter.Write([]byte(msg))
+ if err == nil {
+ w.maxLinesCurLines++
+ w.maxSizeCurSize += len(msg)
+ }
+ w.Unlock()
+ return err
}
-func (w *FileLogWriter) createLogFile() (*os.File, error) {
+func (w *fileLogWriter) createLogFile() (*os.File, error) {
// Open the log file
- fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660)
+ fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, w.Perm)
return fd, err
}
-func (w *FileLogWriter) initFd() error {
- fd := w.mw.fd
- finfo, err := fd.Stat()
+func (w *fileLogWriter) initFd() error {
+ fd := w.fileWriter
+ fInfo, err := fd.Stat()
if err != nil {
return fmt.Errorf("get stat err: %s\n", err)
}
- w.maxsize_cursize = int(finfo.Size())
- w.daily_opendate = time.Now().Day()
- w.maxlines_curlines = 0
- if finfo.Size() > 0 {
+ w.maxSizeCurSize = int(fInfo.Size())
+ w.dailyOpenDate = time.Now().Day()
+ w.maxLinesCurLines = 0
+ if fInfo.Size() > 0 {
count, err := w.lines()
if err != nil {
return err
}
- w.maxlines_curlines = count
+ w.maxLinesCurLines = count
}
return nil
}
-func (w *FileLogWriter) lines() (int, error) {
+func (w *fileLogWriter) lines() (int, error) {
fd, err := os.Open(w.Filename)
if err != nil {
return 0, err
@@ -205,59 +235,60 @@ func (w *FileLogWriter) lines() (int, error) {
}
// DoRotate means it need to write file in new file.
-// new file name like xx.log.2013-01-01.2
-func (w *FileLogWriter) DoRotate() error {
+// new file name like xx.2013-01-01.2.log
+func (w *fileLogWriter) doRotate() error {
_, err := os.Lstat(w.Filename)
- if err == nil { // file exists
- // Find the next available number
- num := 1
- fname := ""
- for ; err == nil && num <= 999; num++ {
- fname = w.Filename + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), num)
- _, err = os.Lstat(fname)
- }
- // return error if the last file checked still existed
- if err == nil {
- return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename)
- }
-
- // block Logger's io.Writer
- w.mw.Lock()
- defer w.mw.Unlock()
-
- fd := w.mw.fd
- fd.Close()
-
- // close fd before rename
- // Rename the file to its newfound home
- err = os.Rename(w.Filename, fname)
- if err != nil {
- return fmt.Errorf("Rotate: %s\n", err)
- }
-
- // re-start logger
- err = w.startLogger()
- if err != nil {
- return fmt.Errorf("Rotate StartLogger: %s\n", err)
- }
-
- go w.deleteOldLog()
+ if err != nil {
+ return err
+ }
+ // file exists
+ // Find the next available number
+ num := 1
+ fName := ""
+ suffix := filepath.Ext(w.Filename)
+ filenameOnly := strings.TrimSuffix(w.Filename, suffix)
+ if suffix == "" {
+ suffix = ".log"
+ }
+ for ; err == nil && num <= 999; num++ {
+ fName = filenameOnly + fmt.Sprintf(".%s.%03d%s", time.Now().Format("2006-01-02"), num, suffix)
+ _, err = os.Lstat(fName)
+ }
+ // return error if the last file checked still existed
+ if err == nil {
+ return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename)
}
+ // close fileWriter before rename
+ w.fileWriter.Close()
+
+ // Rename the file to its new found name
+ // even if occurs error,we MUST guarantee to restart new logger
+ renameErr := os.Rename(w.Filename, fName)
+ // re-start logger
+ startLoggerErr := w.startLogger()
+ go w.deleteOldLog()
+
+ if startLoggerErr != nil {
+ return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr)
+ }
+ if renameErr != nil {
+ return fmt.Errorf("Rotate: %s\n", renameErr)
+ }
return nil
+
}
-func (w *FileLogWriter) deleteOldLog() {
+func (w *fileLogWriter) deleteOldLog() {
dir := filepath.Dir(w.Filename)
filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) {
defer func() {
if r := recover(); r != nil {
- returnErr = fmt.Errorf("Unable to delete old log '%s', error: %+v", path, r)
- fmt.Println(returnErr)
+ fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r)
}
}()
- if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.Maxdays) {
+ if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.MaxDays) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.Filename)) {
os.Remove(path)
}
@@ -266,18 +297,18 @@ func (w *FileLogWriter) deleteOldLog() {
})
}
-// destroy file logger, close file writer.
-func (w *FileLogWriter) Destroy() {
- w.mw.fd.Close()
+// Destroy close the file description, close file writer.
+func (w *fileLogWriter) Destroy() {
+ w.fileWriter.Close()
}
-// flush file logger.
+// Flush flush file logger.
// there are no buffering messages in file logger in memory.
// flush file means sync file from disk.
-func (w *FileLogWriter) Flush() {
- w.mw.fd.Sync()
+func (w *fileLogWriter) Flush() {
+ w.fileWriter.Sync()
}
func init() {
- Register("file", NewFileWriter)
+ Register("file", newFileWriter)
}
diff --git a/logs/file_test.go b/logs/file_test.go
index c71e9bb4..f9b54c26 100644
--- a/logs/file_test.go
+++ b/logs/file_test.go
@@ -23,7 +23,7 @@ import (
"time"
)
-func TestFile(t *testing.T) {
+func TestFile1(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`)
log.Debug("debug")
@@ -34,25 +34,24 @@ func TestFile(t *testing.T) {
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
- time.Sleep(time.Second * 4)
f, err := os.Open("test.log")
if err != nil {
t.Fatal(err)
}
b := bufio.NewReader(f)
- linenum := 0
+ lineNum := 0
for {
line, _, err := b.ReadLine()
if err != nil {
break
}
if len(line) > 0 {
- linenum++
+ lineNum++
}
}
var expected = LevelDebug + 1
- if linenum != expected {
- t.Fatal(linenum, "not "+strconv.Itoa(expected)+" lines")
+ if lineNum != expected {
+ t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
}
os.Remove("test.log")
}
@@ -68,25 +67,24 @@ func TestFile2(t *testing.T) {
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
- time.Sleep(time.Second * 4)
f, err := os.Open("test2.log")
if err != nil {
t.Fatal(err)
}
b := bufio.NewReader(f)
- linenum := 0
+ lineNum := 0
for {
line, _, err := b.ReadLine()
if err != nil {
break
}
if len(line) > 0 {
- linenum++
+ lineNum++
}
}
var expected = LevelError + 1
- if linenum != expected {
- t.Fatal(linenum, "not "+strconv.Itoa(expected)+" lines")
+ if lineNum != expected {
+ t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
}
os.Remove("test2.log")
}
@@ -102,13 +100,13 @@ func TestFileRotate(t *testing.T) {
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
- time.Sleep(time.Second * 4)
- rotatename := "test3.log" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1)
- b, err := exists(rotatename)
+ rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
+ b, err := exists(rotateName)
if !b || err != nil {
+ os.Remove("test3.log")
t.Fatal("rotate not generated")
}
- os.Remove(rotatename)
+ os.Remove(rotateName)
os.Remove("test3.log")
}
@@ -131,3 +129,46 @@ func BenchmarkFile(b *testing.B) {
}
os.Remove("test4.log")
}
+
+
+func BenchmarkFileAsynchronous(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ log.Async()
+ for i := 0; i < b.N; i++ {
+ log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
+
+func BenchmarkFileCallDepth(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ log.EnableFuncCallDepth(true)
+ log.SetLogFuncCallDepth(2)
+ for i := 0; i < b.N; i++ {
+ log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
+
+func BenchmarkFileAsynchronousCallDepth(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ log.EnableFuncCallDepth(true)
+ log.SetLogFuncCallDepth(2)
+ log.Async()
+ for i := 0; i < b.N; i++ {
+ log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
+
+func BenchmarkFileOnGoroutine(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ for i := 0; i < b.N; i++ {
+ go log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
diff --git a/logs/log.go b/logs/log.go
index cebbc737..ccaaa3ad 100644
--- a/logs/log.go
+++ b/logs/log.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package logs provide a general log interface
// Usage:
//
// import "github.com/astaxie/beego/logs"
@@ -34,8 +35,10 @@ package logs
import (
"fmt"
+ "os"
"path"
"runtime"
+ "strconv"
"sync"
)
@@ -60,10 +63,10 @@ const (
LevelWarn = LevelWarning
)
-type loggerType func() LoggerInterface
+type loggerType func() Logger
-// LoggerInterface defines the behavior of a log provider.
-type LoggerInterface interface {
+// Logger defines the behavior of a log provider.
+type Logger interface {
Init(config string) error
WriteMsg(msg string, level int) error
Destroy()
@@ -93,8 +96,13 @@ type BeeLogger struct {
enableFuncCallDepth bool
loggerFuncCallDepth int
asynchronous bool
- msg chan *logMsg
- outputs map[string]LoggerInterface
+ msgChan chan *logMsg
+ outputs []*nameLogger
+}
+
+type nameLogger struct {
+ Logger
+ name string
}
type logMsg struct {
@@ -102,59 +110,79 @@ type logMsg struct {
msg string
}
+var logMsgPool *sync.Pool
+
// NewLogger returns a new BeeLogger.
-// channellen means the number of messages in chan.
+// channelLen means the number of messages in chan(used where asynchronous is true).
// if the buffering chan is full, logger adapters write to file or other way.
-func NewLogger(channellen int64) *BeeLogger {
+func NewLogger(channelLen int64) *BeeLogger {
bl := new(BeeLogger)
bl.level = LevelDebug
bl.loggerFuncCallDepth = 2
- bl.msg = make(chan *logMsg, channellen)
- bl.outputs = make(map[string]LoggerInterface)
+ bl.msgChan = make(chan *logMsg, channelLen)
return bl
}
+// Async set the log to asynchronous and start the goroutine
func (bl *BeeLogger) Async() *BeeLogger {
bl.asynchronous = true
+ logMsgPool = &sync.Pool{
+ New: func() interface{} {
+ return &logMsg{}
+ },
+ }
go bl.startLogger()
return bl
}
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
-func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
+func (bl *BeeLogger) SetLogger(adapterName string, config string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
- if log, ok := adapters[adaptername]; ok {
+ if log, ok := adapters[adapterName]; ok {
lg := log()
err := lg.Init(config)
- bl.outputs[adaptername] = lg
if err != nil {
- fmt.Println("logs.BeeLogger.SetLogger: " + err.Error())
+ fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error())
return err
}
+ bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg})
} else {
- return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername)
+ return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
}
return nil
}
-// remove a logger adapter in BeeLogger.
-func (bl *BeeLogger) DelLogger(adaptername string) error {
+// DelLogger remove a logger adapter in BeeLogger.
+func (bl *BeeLogger) DelLogger(adapterName string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
- if lg, ok := bl.outputs[adaptername]; ok {
- lg.Destroy()
- delete(bl.outputs, adaptername)
- return nil
- } else {
- return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername)
+ outputs := []*nameLogger{}
+ for _, lg := range bl.outputs {
+ if lg.name == adapterName {
+ lg.Destroy()
+ } else {
+ outputs = append(outputs, lg)
+ }
+ }
+ if len(outputs) == len(bl.outputs) {
+ return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
+ }
+ bl.outputs = outputs
+ return nil
+}
+
+func (bl *BeeLogger) writeToLoggers(msg string, level int) {
+ for _, l := range bl.outputs {
+ err := l.WriteMsg(msg, level)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err)
+ }
}
}
-func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
- lm := new(logMsg)
- lm.level = loglevel
+func (bl *BeeLogger) writeMsg(logLevel int, msg string) error {
if bl.enableFuncCallDepth {
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
if !ok {
@@ -162,43 +190,37 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
line = 0
}
_, filename := path.Split(file)
- lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg)
- } else {
- lm.msg = msg
+ msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "]" + msg
}
if bl.asynchronous {
- bl.msg <- lm
+ lm := logMsgPool.Get().(*logMsg)
+ lm.level = logLevel
+ lm.msg = msg
+ bl.msgChan <- lm
} else {
- for name, l := range bl.outputs {
- err := l.WriteMsg(lm.msg, lm.level)
- if err != nil {
- fmt.Println("unable to WriteMsg to adapter:", name, err)
- return err
- }
- }
+ bl.writeToLoggers(msg, logLevel)
}
return nil
}
-// Set log message level.
-//
+// SetLevel Set log message level.
// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning),
// log providers will not even be sent the message.
func (bl *BeeLogger) SetLevel(l int) {
bl.level = l
}
-// set log funcCallDepth
+// SetLogFuncCallDepth set log funcCallDepth
func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
bl.loggerFuncCallDepth = d
}
-// get log funcCallDepth for wrapper
+// GetLogFuncCallDepth return log funcCallDepth for wrapper
func (bl *BeeLogger) GetLogFuncCallDepth() int {
return bl.loggerFuncCallDepth
}
-// enable log funcCallDepth
+// EnableFuncCallDepth enable log funcCallDepth
func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
bl.enableFuncCallDepth = b
}
@@ -208,137 +230,129 @@ func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
func (bl *BeeLogger) startLogger() {
for {
select {
- case bm := <-bl.msg:
- for _, l := range bl.outputs {
- err := l.WriteMsg(bm.msg, bm.level)
- if err != nil {
- fmt.Println("ERROR, unable to WriteMsg:", err)
- }
- }
+ case bm := <-bl.msgChan:
+ bl.writeToLoggers(bm.msg, bm.level)
+ logMsgPool.Put(bm)
}
}
}
-// Log EMERGENCY level message.
+// Emergency Log EMERGENCY level message.
func (bl *BeeLogger) Emergency(format string, v ...interface{}) {
if LevelEmergency > bl.level {
return
}
msg := fmt.Sprintf("[M] "+format, v...)
- bl.writerMsg(LevelEmergency, msg)
+ bl.writeMsg(LevelEmergency, msg)
}
-// Log ALERT level message.
+// Alert Log ALERT level message.
func (bl *BeeLogger) Alert(format string, v ...interface{}) {
if LevelAlert > bl.level {
return
}
msg := fmt.Sprintf("[A] "+format, v...)
- bl.writerMsg(LevelAlert, msg)
+ bl.writeMsg(LevelAlert, msg)
}
-// Log CRITICAL level message.
+// Critical Log CRITICAL level message.
func (bl *BeeLogger) Critical(format string, v ...interface{}) {
if LevelCritical > bl.level {
return
}
msg := fmt.Sprintf("[C] "+format, v...)
- bl.writerMsg(LevelCritical, msg)
+ bl.writeMsg(LevelCritical, msg)
}
-// Log ERROR level message.
+// Error Log ERROR level message.
func (bl *BeeLogger) Error(format string, v ...interface{}) {
if LevelError > bl.level {
return
}
msg := fmt.Sprintf("[E] "+format, v...)
- bl.writerMsg(LevelError, msg)
+ bl.writeMsg(LevelError, msg)
}
-// Log WARNING level message.
+// Warning Log WARNING level message.
func (bl *BeeLogger) Warning(format string, v ...interface{}) {
if LevelWarning > bl.level {
return
}
msg := fmt.Sprintf("[W] "+format, v...)
- bl.writerMsg(LevelWarning, msg)
+ bl.writeMsg(LevelWarning, msg)
}
-// Log NOTICE level message.
+// Notice Log NOTICE level message.
func (bl *BeeLogger) Notice(format string, v ...interface{}) {
if LevelNotice > bl.level {
return
}
msg := fmt.Sprintf("[N] "+format, v...)
- bl.writerMsg(LevelNotice, msg)
+ bl.writeMsg(LevelNotice, msg)
}
-// Log INFORMATIONAL level message.
+// Informational Log INFORMATIONAL level message.
func (bl *BeeLogger) Informational(format string, v ...interface{}) {
if LevelInformational > bl.level {
return
}
msg := fmt.Sprintf("[I] "+format, v...)
- bl.writerMsg(LevelInformational, msg)
+ bl.writeMsg(LevelInformational, msg)
}
-// Log DEBUG level message.
+// Debug Log DEBUG level message.
func (bl *BeeLogger) Debug(format string, v ...interface{}) {
if LevelDebug > bl.level {
return
}
msg := fmt.Sprintf("[D] "+format, v...)
- bl.writerMsg(LevelDebug, msg)
+ bl.writeMsg(LevelDebug, msg)
}
-// Log WARN level message.
+// Warn Log WARN level message.
// compatibility alias for Warning()
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
if LevelWarning > bl.level {
return
}
msg := fmt.Sprintf("[W] "+format, v...)
- bl.writerMsg(LevelWarning, msg)
+ bl.writeMsg(LevelWarning, msg)
}
-// Log INFO level message.
+// Info Log INFO level message.
// compatibility alias for Informational()
func (bl *BeeLogger) Info(format string, v ...interface{}) {
if LevelInformational > bl.level {
return
}
msg := fmt.Sprintf("[I] "+format, v...)
- bl.writerMsg(LevelInformational, msg)
+ bl.writeMsg(LevelInformational, msg)
}
-// Log TRACE level message.
+// Trace Log TRACE level message.
// compatibility alias for Debug()
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
if LevelDebug > bl.level {
return
}
msg := fmt.Sprintf("[D] "+format, v...)
- bl.writerMsg(LevelDebug, msg)
+ bl.writeMsg(LevelDebug, msg)
}
-// flush all chan data.
+// Flush flush all chan data.
func (bl *BeeLogger) Flush() {
for _, l := range bl.outputs {
l.Flush()
}
}
-// close logger, flush all chan data and destroy all adapters in BeeLogger.
+// Close close logger, flush all chan data and destroy all adapters in BeeLogger.
func (bl *BeeLogger) Close() {
for {
- if len(bl.msg) > 0 {
- bm := <-bl.msg
- for _, l := range bl.outputs {
- err := l.WriteMsg(bm.msg, bm.level)
- if err != nil {
- fmt.Println("ERROR, unable to WriteMsg (while closing logger):", err)
- }
- }
+ if len(bl.msgChan) > 0 {
+ bm := <-bl.msgChan
+ bl.writeToLoggers(bm.msg, bm.level)
+ logMsgPool.Put(bm)
continue
}
break
diff --git a/logs/smtp.go b/logs/smtp.go
index 95123ebf..748462f9 100644
--- a/logs/smtp.go
+++ b/logs/smtp.go
@@ -24,31 +24,26 @@ import (
"time"
)
-const (
-// no usage
-// subjectPhrase = "Diagnostic message from server"
-)
-
-// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
-type SmtpWriter struct {
- Username string `json:"Username"`
+// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server.
+type SMTPWriter struct {
+ Username string `json:"username"`
Password string `json:"password"`
- Host string `json:"Host"`
+ Host string `json:"host"`
Subject string `json:"subject"`
FromAddress string `json:"fromAddress"`
RecipientAddresses []string `json:"sendTos"`
Level int `json:"level"`
}
-// create smtp writer.
-func NewSmtpWriter() LoggerInterface {
- return &SmtpWriter{Level: LevelTrace}
+// NewSMTPWriter create smtp writer.
+func newSMTPWriter() Logger {
+ return &SMTPWriter{Level: LevelTrace}
}
-// init smtp writer with json config.
+// Init smtp writer with json config.
// config like:
// {
-// "Username":"example@gmail.com",
+// "username":"example@gmail.com",
// "password:"password",
// "host":"smtp.gmail.com:465",
// "subject":"email title",
@@ -56,7 +51,7 @@ func NewSmtpWriter() LoggerInterface {
// "sendTos":["email1","email2"],
// "level":LevelError
// }
-func (s *SmtpWriter) Init(jsonconfig string) error {
+func (s *SMTPWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s)
if err != nil {
return err
@@ -64,7 +59,7 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
return nil
}
-func (s *SmtpWriter) GetSmtpAuth(host string) smtp.Auth {
+func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth {
if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 {
return nil
}
@@ -76,7 +71,7 @@ func (s *SmtpWriter) GetSmtpAuth(host string) smtp.Auth {
)
}
-func (s *SmtpWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error {
+func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error {
client, err := smtp.Dial(hostAddressWithPort)
if err != nil {
return err
@@ -129,9 +124,9 @@ func (s *SmtpWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd
return nil
}
-// write message in smtp writer.
+// WriteMsg write message in smtp writer.
// it will send an email with subject and only this message.
-func (s *SmtpWriter) WriteMsg(msg string, level int) error {
+func (s *SMTPWriter) WriteMsg(msg string, level int) error {
if level > s.Level {
return nil
}
@@ -139,27 +134,27 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
hp := strings.Split(s.Host, ":")
// Set up authentication information.
- auth := s.GetSmtpAuth(hp[0])
+ auth := s.getSMTPAuth(hp[0])
// Connect to the server, authenticate, set the sender and recipient,
// and send the email all in one step.
- content_type := "Content-Type: text/plain" + "; charset=UTF-8"
+ contentType := "Content-Type: text/plain" + "; charset=UTF-8"
mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress +
- ">\r\nSubject: " + s.Subject + "\r\n" + content_type + "\r\n\r\n" + fmt.Sprintf(".%s", time.Now().Format("2006-01-02 15:04:05")) + msg)
+ ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", time.Now().Format("2006-01-02 15:04:05")) + msg)
return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg)
}
-// implementing method. empty.
-func (s *SmtpWriter) Flush() {
+// Flush implementing method. empty.
+func (s *SMTPWriter) Flush() {
return
}
-// implementing method. empty.
-func (s *SmtpWriter) Destroy() {
+// Destroy implementing method. empty.
+func (s *SMTPWriter) Destroy() {
return
}
func init() {
- Register("smtp", NewSmtpWriter)
+ Register("smtp", newSMTPWriter)
}
diff --git a/memzipfile.go b/memzipfile.go
deleted file mode 100644
index cc5e3851..00000000
--- a/memzipfile.go
+++ /dev/null
@@ -1,212 +0,0 @@
-// Copyright 2014 beego Author. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package beego
-
-import (
- "bytes"
- "compress/flate"
- "compress/gzip"
- "errors"
- "io"
- "io/ioutil"
- "net/http"
- "os"
- "strings"
- "sync"
- "time"
-)
-
-var gmfim map[string]*memFileInfo = make(map[string]*memFileInfo)
-var lock sync.RWMutex
-
-// OpenMemZipFile returns MemFile object with a compressed static file.
-// it's used for serve static file if gzip enable.
-func openMemZipFile(path string, zip string) (*memFile, error) {
- osfile, e := os.Open(path)
- if e != nil {
- return nil, e
- }
- defer osfile.Close()
-
- osfileinfo, e := osfile.Stat()
- if e != nil {
- return nil, e
- }
-
- modtime := osfileinfo.ModTime()
- fileSize := osfileinfo.Size()
- lock.RLock()
- cfi, ok := gmfim[zip+":"+path]
- lock.RUnlock()
- if !(ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize) {
- var content []byte
- if zip == "gzip" {
- var zipbuf bytes.Buffer
- gzipwriter, e := gzip.NewWriterLevel(&zipbuf, gzip.BestCompression)
- if e != nil {
- return nil, e
- }
- _, e = io.Copy(gzipwriter, osfile)
- gzipwriter.Close()
- if e != nil {
- return nil, e
- }
- content, e = ioutil.ReadAll(&zipbuf)
- if e != nil {
- return nil, e
- }
- } else if zip == "deflate" {
- var zipbuf bytes.Buffer
- deflatewriter, e := flate.NewWriter(&zipbuf, flate.BestCompression)
- if e != nil {
- return nil, e
- }
- _, e = io.Copy(deflatewriter, osfile)
- deflatewriter.Close()
- if e != nil {
- return nil, e
- }
- content, e = ioutil.ReadAll(&zipbuf)
- if e != nil {
- return nil, e
- }
- } else {
- content, e = ioutil.ReadAll(osfile)
- if e != nil {
- return nil, e
- }
- }
-
- cfi = &memFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
- lock.Lock()
- defer lock.Unlock()
- gmfim[zip+":"+path] = cfi
- }
- return &memFile{fi: cfi, offset: 0}, nil
-}
-
-// MemFileInfo contains a compressed file bytes and file information.
-// it implements os.FileInfo interface.
-type memFileInfo struct {
- os.FileInfo
- modTime time.Time
- content []byte
- contentSize int64
- fileSize int64
-}
-
-// Name returns the compressed filename.
-func (fi *memFileInfo) Name() string {
- return fi.Name()
-}
-
-// Size returns the raw file content size, not compressed size.
-func (fi *memFileInfo) Size() int64 {
- return fi.contentSize
-}
-
-// Mode returns file mode.
-func (fi *memFileInfo) Mode() os.FileMode {
- return fi.Mode()
-}
-
-// ModTime returns the last modified time of raw file.
-func (fi *memFileInfo) ModTime() time.Time {
- return fi.modTime
-}
-
-// IsDir returns the compressing file is a directory or not.
-func (fi *memFileInfo) IsDir() bool {
- return fi.IsDir()
-}
-
-// return nil. implement the os.FileInfo interface method.
-func (fi *memFileInfo) Sys() interface{} {
- return nil
-}
-
-// MemFile contains MemFileInfo and bytes offset when reading.
-// it implements io.Reader,io.ReadCloser and io.Seeker.
-type memFile struct {
- fi *memFileInfo
- offset int64
-}
-
-// Close memfile.
-func (f *memFile) Close() error {
- return nil
-}
-
-// Get os.FileInfo of memfile.
-func (f *memFile) Stat() (os.FileInfo, error) {
- return f.fi, nil
-}
-
-// read os.FileInfo of files in directory of memfile.
-// it returns empty slice.
-func (f *memFile) Readdir(count int) ([]os.FileInfo, error) {
- infos := []os.FileInfo{}
-
- return infos, nil
-}
-
-// Read bytes from the compressed file bytes.
-func (f *memFile) Read(p []byte) (n int, err error) {
- if len(f.fi.content)-int(f.offset) >= len(p) {
- n = len(p)
- } else {
- n = len(f.fi.content) - int(f.offset)
- err = io.EOF
- }
- copy(p, f.fi.content[f.offset:f.offset+int64(n)])
- f.offset += int64(n)
- return
-}
-
-var errWhence = errors.New("Seek: invalid whence")
-var errOffset = errors.New("Seek: invalid offset")
-
-// Read bytes from the compressed file bytes by seeker.
-func (f *memFile) Seek(offset int64, whence int) (ret int64, err error) {
- switch whence {
- default:
- return 0, errWhence
- case os.SEEK_SET:
- case os.SEEK_CUR:
- offset += f.offset
- case os.SEEK_END:
- offset += int64(len(f.fi.content))
- }
- if offset < 0 || int(offset) > len(f.fi.content) {
- return 0, errOffset
- }
- f.offset = offset
- return f.offset, nil
-}
-
-// GetAcceptEncodingZip returns accept encoding format in http header.
-// zip is first, then deflate if both accepted.
-// If no accepted, return empty string.
-func getAcceptEncodingZip(r *http.Request) string {
- ss := r.Header.Get("Accept-Encoding")
- ss = strings.ToLower(ss)
- if strings.Contains(ss, "gzip") {
- return "gzip"
- } else if strings.Contains(ss, "deflate") {
- return "deflate"
- } else {
- return ""
- }
-}
diff --git a/middleware/i18n.go b/middleware/i18n.go
deleted file mode 100644
index f54b4bb5..00000000
--- a/middleware/i18n.go
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2014 beego Author. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Usage:
-//
-// import "github.com/astaxie/beego/middleware"
-//
-// I18N = middleware.NewLocale("conf/i18n.conf", beego.AppConfig.String("language"))
-//
-// more docs: http://beego.me/docs/module/i18n.md
-package middleware
-
-import (
- "encoding/json"
- "io/ioutil"
- "os"
-)
-
-type Translation struct {
- filepath string
- CurrentLocal string
- Locales map[string]map[string]string
-}
-
-func NewLocale(filepath string, defaultlocal string) *Translation {
- file, err := os.Open(filepath)
- if err != nil {
- panic("open " + filepath + " err :" + err.Error())
- }
- data, err := ioutil.ReadAll(file)
- if err != nil {
- panic("read " + filepath + " err :" + err.Error())
- }
-
- i18n := make(map[string]map[string]string)
- if err = json.Unmarshal(data, &i18n); err != nil {
- panic("json.Unmarshal " + filepath + " err :" + err.Error())
- }
- return &Translation{
- filepath: filepath,
- CurrentLocal: defaultlocal,
- Locales: i18n,
- }
-}
-
-func (t *Translation) SetLocale(local string) {
- t.CurrentLocal = local
-}
-
-func (t *Translation) Translate(key string, local string) string {
- if local == "" {
- local = t.CurrentLocal
- }
- if ct, ok := t.Locales[key]; ok {
- if v, o := ct[local]; o {
- return v
- }
- }
- return key
-}
diff --git a/migration/ddl.go b/migration/ddl.go
index f9b60117..51243337 100644
--- a/migration/ddl.go
+++ b/migration/ddl.go
@@ -14,33 +14,40 @@
package migration
+// Table store the tablename and Column
type Table struct {
TableName string
Columns []*Column
}
+// Create return the create sql
func (t *Table) Create() string {
return ""
}
+// Drop return the drop sql
func (t *Table) Drop() string {
return ""
}
+// Column define the columns name type and Default
type Column struct {
Name string
Type string
Default interface{}
}
+// Create return create sql with the provided tbname and columns
func Create(tbname string, columns ...Column) string {
return ""
}
+// Drop return the drop sql with the provided tbname and columns
func Drop(tbname string, columns ...Column) string {
return ""
}
+// TableDDL is still in think
func TableDDL(tbname string, columns ...Column) string {
return ""
}
diff --git a/migration/migration.go b/migration/migration.go
index d64d60d3..1591bc50 100644
--- a/migration/migration.go
+++ b/migration/migration.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// migration package for migration
+// Package migration is used for migration
//
// The table structure is as follow:
//
@@ -39,8 +39,8 @@ import (
// const the data format for the bee generate migration datatype
const (
- M_DATE_FORMAT = "20060102_150405"
- M_DB_DATE_FORMAT = "2006-01-02 15:04:05"
+ DateFormat = "20060102_150405"
+ DBDateFormat = "2006-01-02 15:04:05"
)
// Migrationer is an interface for all Migration struct
@@ -60,24 +60,24 @@ func init() {
migrationMap = make(map[string]Migrationer)
}
-// the basic type which will implement the basic type
+// Migration the basic type which will implement the basic type
type Migration struct {
sqls []string
Created string
}
-// implement in the Inheritance struct for upgrade
+// Up implement in the Inheritance struct for upgrade
func (m *Migration) Up() {
}
-// implement in the Inheritance struct for down
+// Down implement in the Inheritance struct for down
func (m *Migration) Down() {
}
-// add sql want to execute
-func (m *Migration) Sql(sql string) {
+// SQL add sql want to execute
+func (m *Migration) SQL(sql string) {
m.sqls = append(m.sqls, sql)
}
@@ -86,7 +86,7 @@ func (m *Migration) Reset() {
m.sqls = make([]string, 0)
}
-// execute the sql already add in the sql
+// Exec execute the sql already add in the sql
func (m *Migration) Exec(name, status string) error {
o := orm.NewOrm()
for _, s := range m.sqls {
@@ -104,33 +104,32 @@ func (m *Migration) addOrUpdateRecord(name, status string) error {
o := orm.NewOrm()
if status == "down" {
status = "rollback"
- p, err := o.Raw("update migrations set `status` = ?, `rollback_statements` = ?, `created_at` = ? where name = ?").Prepare()
+ p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare()
if err != nil {
return nil
}
- _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(M_DB_DATE_FORMAT), name)
- return err
- } else {
- status = "update"
- p, err := o.Raw("insert into migrations(`name`, `created_at`, `statements`, `status`) values(?,?,?,?)").Prepare()
- if err != nil {
- return err
- }
- _, err = p.Exec(name, time.Now().Format(M_DB_DATE_FORMAT), strings.Join(m.sqls, "; "), status)
+ _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name)
return err
}
+ status = "update"
+ p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare()
+ if err != nil {
+ return err
+ }
+ _, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status)
+ return err
}
-// get the unixtime from the Created
+// GetCreated get the unixtime from the Created
func (m *Migration) GetCreated() int64 {
- t, err := time.Parse(M_DATE_FORMAT, m.Created)
+ t, err := time.Parse(DateFormat, m.Created)
if err != nil {
return 0
}
return t.Unix()
}
-// register the Migration in the map
+// Register register the Migration in the map
func Register(name string, m Migrationer) error {
if _, ok := migrationMap[name]; ok {
return errors.New("already exist name:" + name)
@@ -139,7 +138,7 @@ func Register(name string, m Migrationer) error {
return nil
}
-// upgrate the migration from lasttime
+// Upgrade upgrate the migration from lasttime
func Upgrade(lasttime int64) error {
sm := sortMap(migrationMap)
i := 0
@@ -163,7 +162,7 @@ func Upgrade(lasttime int64) error {
return nil
}
-//rollback the migration by the name
+// Rollback rollback the migration by the name
func Rollback(name string) error {
if v, ok := migrationMap[name]; ok {
beego.Info("start rollback")
@@ -178,14 +177,13 @@ func Rollback(name string) error {
beego.Info("end rollback")
time.Sleep(2 * time.Second)
return nil
- } else {
- beego.Error("not exist the migrationMap name:" + name)
- time.Sleep(2 * time.Second)
- return errors.New("not exist the migrationMap name:" + name)
}
+ beego.Error("not exist the migrationMap name:" + name)
+ time.Sleep(2 * time.Second)
+ return errors.New("not exist the migrationMap name:" + name)
}
-// reset all migration
+// Reset reset all migration
// run all migration's down function
func Reset() error {
sm := sortMap(migrationMap)
@@ -214,7 +212,7 @@ func Reset() error {
return nil
}
-// first Reset, then Upgrade
+// Refresh first Reset, then Upgrade
func Refresh() error {
err := Reset()
if err != nil {
diff --git a/mime.go b/mime.go
index 20246c21..e85fcb2a 100644
--- a/mime.go
+++ b/mime.go
@@ -14,11 +14,7 @@
package beego
-import (
- "mime"
-)
-
-var mimemaps map[string]string = map[string]string{
+var mimemaps = map[string]string{
".3dm": "x-world/x-3dmf",
".3dmf": "x-world/x-3dmf",
".7z": "application/x-7z-compressed",
@@ -558,10 +554,3 @@ var mimemaps map[string]string = map[string]string{
".oex": "application/x-opera-extension",
".mustache": "text/html",
}
-
-func initMime() error {
- for k, v := range mimemaps {
- mime.AddExtensionType(k, v)
- }
- return nil
-}
diff --git a/namespace.go b/namespace.go
index ebb7c14f..0dfdd7af 100644
--- a/namespace.go
+++ b/namespace.go
@@ -23,16 +23,17 @@ import (
type namespaceCond func(*beecontext.Context) bool
-type innnerNamespace func(*Namespace)
+// LinkNamespace used as link action
+type LinkNamespace func(*Namespace)
// Namespace is store all the info
type Namespace struct {
prefix string
- handlers *ControllerRegistor
+ handlers *ControllerRegister
}
-// get new Namespace
-func NewNamespace(prefix string, params ...innnerNamespace) *Namespace {
+// NewNamespace get new Namespace
+func NewNamespace(prefix string, params ...LinkNamespace) *Namespace {
ns := &Namespace{
prefix: prefix,
handlers: NewControllerRegister(),
@@ -43,7 +44,7 @@ func NewNamespace(prefix string, params ...innnerNamespace) *Namespace {
return ns
}
-// set condtion function
+// Cond set condtion function
// if cond return true can run this namespace, else can't
// usage:
// ns.Cond(func (ctx *context.Context) bool{
@@ -72,7 +73,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace {
return n
}
-// add filter in the Namespace
+// Filter add filter in the Namespace
// action has before & after
// FilterFunc
// usage:
@@ -95,98 +96,98 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
return n
}
-// same as beego.Rourer
+// Router same as beego.Rourer
// refer: https://godoc.org/github.com/astaxie/beego#Router
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
n.handlers.Add(rootpath, c, mappingMethods...)
return n
}
-// same as beego.AutoRouter
+// AutoRouter same as beego.AutoRouter
// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter
func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace {
n.handlers.AddAuto(c)
return n
}
-// same as beego.AutoPrefix
+// AutoPrefix same as beego.AutoPrefix
// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix
func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace {
n.handlers.AddAutoPrefix(prefix, c)
return n
}
-// same as beego.Get
+// Get same as beego.Get
// refer: https://godoc.org/github.com/astaxie/beego#Get
func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace {
n.handlers.Get(rootpath, f)
return n
}
-// same as beego.Post
+// Post same as beego.Post
// refer: https://godoc.org/github.com/astaxie/beego#Post
func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace {
n.handlers.Post(rootpath, f)
return n
}
-// same as beego.Delete
+// Delete same as beego.Delete
// refer: https://godoc.org/github.com/astaxie/beego#Delete
func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace {
n.handlers.Delete(rootpath, f)
return n
}
-// same as beego.Put
+// Put same as beego.Put
// refer: https://godoc.org/github.com/astaxie/beego#Put
func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace {
n.handlers.Put(rootpath, f)
return n
}
-// same as beego.Head
+// Head same as beego.Head
// refer: https://godoc.org/github.com/astaxie/beego#Head
func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace {
n.handlers.Head(rootpath, f)
return n
}
-// same as beego.Options
+// Options same as beego.Options
// refer: https://godoc.org/github.com/astaxie/beego#Options
func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace {
n.handlers.Options(rootpath, f)
return n
}
-// same as beego.Patch
+// Patch same as beego.Patch
// refer: https://godoc.org/github.com/astaxie/beego#Patch
func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace {
n.handlers.Patch(rootpath, f)
return n
}
-// same as beego.Any
+// Any same as beego.Any
// refer: https://godoc.org/github.com/astaxie/beego#Any
func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace {
n.handlers.Any(rootpath, f)
return n
}
-// same as beego.Handler
+// Handler same as beego.Handler
// refer: https://godoc.org/github.com/astaxie/beego#Handler
func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace {
n.handlers.Handler(rootpath, h)
return n
}
-// add include class
+// Include add include class
// refer: https://godoc.org/github.com/astaxie/beego#Include
func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
n.handlers.Include(cList...)
return n
}
-// nest Namespace
+// Namespace add nest Namespace
// usage:
//ns := beego.NewNamespace(“/v1”).
//Namespace(
@@ -230,7 +231,7 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
return n
}
-// register Namespace into beego.Handler
+// AddNamespace register Namespace into beego.Handler
// support multi Namespace
func AddNamespace(nl ...*Namespace) {
for _, n := range nl {
@@ -275,113 +276,113 @@ func addPrefix(t *Tree, prefix string) {
}
-// Namespace Condition
-func NSCond(cond namespaceCond) innnerNamespace {
+// NSCond is Namespace Condition
+func NSCond(cond namespaceCond) LinkNamespace {
return func(ns *Namespace) {
ns.Cond(cond)
}
}
-// Namespace BeforeRouter filter
-func NSBefore(filiterList ...FilterFunc) innnerNamespace {
+// NSBefore Namespace BeforeRouter filter
+func NSBefore(filiterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Filter("before", filiterList...)
}
}
-// Namespace FinishRouter filter
-func NSAfter(filiterList ...FilterFunc) innnerNamespace {
+// NSAfter add Namespace FinishRouter filter
+func NSAfter(filiterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Filter("after", filiterList...)
}
}
-// Namespace Include ControllerInterface
-func NSInclude(cList ...ControllerInterface) innnerNamespace {
+// NSInclude Namespace Include ControllerInterface
+func NSInclude(cList ...ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.Include(cList...)
}
}
-// Namespace Router
-func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) innnerNamespace {
+// NSRouter call Namespace Router
+func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace {
return func(ns *Namespace) {
ns.Router(rootpath, c, mappingMethods...)
}
}
-// Namespace Get
-func NSGet(rootpath string, f FilterFunc) innnerNamespace {
+// NSGet call Namespace Get
+func NSGet(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Get(rootpath, f)
}
}
-// Namespace Post
-func NSPost(rootpath string, f FilterFunc) innnerNamespace {
+// NSPost call Namespace Post
+func NSPost(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Post(rootpath, f)
}
}
-// Namespace Head
-func NSHead(rootpath string, f FilterFunc) innnerNamespace {
+// NSHead call Namespace Head
+func NSHead(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Head(rootpath, f)
}
}
-// Namespace Put
-func NSPut(rootpath string, f FilterFunc) innnerNamespace {
+// NSPut call Namespace Put
+func NSPut(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Put(rootpath, f)
}
}
-// Namespace Delete
-func NSDelete(rootpath string, f FilterFunc) innnerNamespace {
+// NSDelete call Namespace Delete
+func NSDelete(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Delete(rootpath, f)
}
}
-// Namespace Any
-func NSAny(rootpath string, f FilterFunc) innnerNamespace {
+// NSAny call Namespace Any
+func NSAny(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Any(rootpath, f)
}
}
-// Namespace Options
-func NSOptions(rootpath string, f FilterFunc) innnerNamespace {
+// NSOptions call Namespace Options
+func NSOptions(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Options(rootpath, f)
}
}
-// Namespace Patch
-func NSPatch(rootpath string, f FilterFunc) innnerNamespace {
+// NSPatch call Namespace Patch
+func NSPatch(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Patch(rootpath, f)
}
}
-//Namespace AutoRouter
-func NSAutoRouter(c ControllerInterface) innnerNamespace {
+// NSAutoRouter call Namespace AutoRouter
+func NSAutoRouter(c ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.AutoRouter(c)
}
}
-// Namespace AutoPrefix
-func NSAutoPrefix(prefix string, c ControllerInterface) innnerNamespace {
+// NSAutoPrefix call Namespace AutoPrefix
+func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.AutoPrefix(prefix, c)
}
}
-// Namespace add sub Namespace
-func NSNamespace(prefix string, params ...innnerNamespace) innnerNamespace {
+// NSNamespace add sub Namespace
+func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace {
return func(ns *Namespace) {
n := NewNamespace(prefix, params...)
ns.Namespace(n)
diff --git a/orm/cmd.go b/orm/cmd.go
index 2358ef3c..3638a75c 100644
--- a/orm/cmd.go
+++ b/orm/cmd.go
@@ -46,7 +46,7 @@ func printHelp(errs ...string) {
os.Exit(2)
}
-// listen for orm command and then run it if command arguments passed.
+// RunCommand listen for orm command and then run it if command arguments passed.
func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
@@ -100,7 +100,7 @@ func (d *commandSyncDb) Parse(args []string) {
func (d *commandSyncDb) Run() error {
var drops []string
if d.force {
- drops = getDbDropSql(d.al)
+ drops = getDbDropSQL(d.al)
}
db := d.al.DB
@@ -124,7 +124,7 @@ func (d *commandSyncDb) Run() error {
}
}
- sqls, indexes := getDbCreateSql(d.al)
+ sqls, indexes := getDbCreateSQL(d.al)
tables, err := d.al.DbBaser.GetTables(db)
if err != nil {
@@ -180,7 +180,7 @@ func (d *commandSyncDb) Run() error {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}
- query := idx.Sql
+ query := idx.SQL
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
@@ -203,7 +203,7 @@ func (d *commandSyncDb) Run() error {
queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] {
- queries = append(queries, idx.Sql)
+ queries = append(queries, idx.SQL)
}
for _, query := range queries {
@@ -228,12 +228,12 @@ func (d *commandSyncDb) Run() error {
}
// database creation commander interface implement.
-type commandSqlAll struct {
+type commandSQLAll struct {
al *alias
}
// parse orm command line arguments.
-func (d *commandSqlAll) Parse(args []string) {
+func (d *commandSQLAll) Parse(args []string) {
var name string
flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError)
@@ -244,13 +244,13 @@ func (d *commandSqlAll) Parse(args []string) {
}
// run orm line command.
-func (d *commandSqlAll) Run() error {
- sqls, indexes := getDbCreateSql(d.al)
+func (d *commandSQLAll) Run() error {
+ sqls, indexes := getDbCreateSQL(d.al)
var all []string
for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] {
- queries = append(queries, idx.Sql)
+ queries = append(queries, idx.SQL)
}
sql := strings.Join(queries, "\n")
all = append(all, sql)
@@ -262,10 +262,10 @@ func (d *commandSqlAll) Run() error {
func init() {
commands["syncdb"] = new(commandSyncDb)
- commands["sqlall"] = new(commandSqlAll)
+ commands["sqlall"] = new(commandSQLAll)
}
-// run syncdb command line.
+// RunSyncdb run syncdb command line.
// name means table's alias name. default is "default".
// force means run next sql if the current is error.
// verbose means show all info when running command or not.
diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go
index ea105624..da0ee8ab 100644
--- a/orm/cmd_utils.go
+++ b/orm/cmd_utils.go
@@ -23,11 +23,11 @@ import (
type dbIndex struct {
Table string
Name string
- Sql string
+ SQL string
}
// create database drop sql.
-func getDbDropSql(al *alias) (sqls []string) {
+func getDbDropSQL(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
@@ -45,13 +45,14 @@ func getDbDropSql(al *alias) (sqls []string) {
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes()
fieldType := fi.fieldType
+ fieldSize := fi.size
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeCharField:
- col = fmt.Sprintf(T["string"], fi.size)
+ col = fmt.Sprintf(T["string"], fieldSize)
case TypeTextField:
col = T["string-text"]
case TypeDateField:
@@ -65,7 +66,7 @@ checkColumn:
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
- if al.Driver == DR_Sqlite {
+ if al.Driver == DRSqlite {
fieldType = TypeIntegerField
goto checkColumn
}
@@ -89,6 +90,7 @@ checkColumn:
}
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
+ fieldSize = fi.relModelInfo.fields.pk.size
goto checkColumn
}
@@ -104,15 +106,15 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
typ += " " + "NOT NULL"
}
- return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
- Q, fi.mi.table, Q,
- Q, fi.column, Q,
+ return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
+ Q, fi.mi.table, Q,
+ Q, fi.column, Q,
typ, getColumnDefault(fi),
)
}
// create database creation string.
-func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
+func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
@@ -142,7 +144,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
if fi.auto {
switch al.Driver {
- case DR_Sqlite, DR_Postgres:
+ case DRSqlite, DRPostgres:
column += T["auto"]
default:
column += col + " " + T["auto"]
@@ -159,7 +161,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
//if fi.initial.String() != "" {
// column += " DEFAULT " + fi.initial.String()
//}
-
+
// Append attribute DEFAULT
column += getColumnDefault(fi)
@@ -201,7 +203,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
sql += strings.Join(columns, ",\n")
sql += "\n)"
- if al.Driver == DR_MySQL {
+ if al.Driver == DRMySQL {
var engine string
if mi.model != nil {
engine = getTableEngine(mi.addrField)
@@ -237,7 +239,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
index := dbIndex{}
index.Table = mi.table
index.Name = name
- index.Sql = sql
+ index.SQL = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
}
@@ -247,7 +249,6 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
return
}
-
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
func getColumnDefault(fi *fieldInfo) string {
var (
@@ -263,16 +264,20 @@ func getColumnDefault(fi *fieldInfo) string {
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType {
- case TypeDateField, TypeDateTimeField:
- return v;
-
- case TypeBooleanField, TypeBitField, TypeSmallIntegerField, TypeIntegerField,
- TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
+ case TypeDateField, TypeDateTimeField, TypeTextField:
+ return v
+
+ case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
+ TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField,
TypeDecimalField:
- d = "0"
+ t = " DEFAULT %s "
+ d = "0"
+ case TypeBooleanField:
+ t = " DEFAULT %s "
+ d = "FALSE"
}
-
+
if fi.colDefault {
if !fi.initial.Exist() {
v = fmt.Sprintf(t, "")
diff --git a/orm/db.go b/orm/db.go
index 20dc80f2..b62c165b 100644
--- a/orm/db.go
+++ b/orm/db.go
@@ -24,12 +24,13 @@ import (
)
const (
- format_Date = "2006-01-02"
- format_DateTime = "2006-01-02 15:04:05"
+ formatDate = "2006-01-02"
+ formatDateTime = "2006-01-02 15:04:05"
)
var (
- ErrMissPK = errors.New("missed pk value") // missing pk error
+ // ErrMissPK missing pk error
+ ErrMissPK = errors.New("missed pk value")
)
var (
@@ -216,14 +217,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
}
}
if fi.null == false && value == nil {
- return nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
+ return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
}
}
}
}
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
- if fi.auto_now || fi.auto_now_add && insert {
+ if fi.autoNow || fi.autoNowAdd && insert {
if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() {
break
@@ -282,13 +283,12 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
var id int64
err := row.Scan(&id)
return id, err
- } else {
- if res, err := stmt.Exec(values...); err == nil {
- return res.LastInsertId()
- } else {
- return 0, err
- }
}
+ res, err := stmt.Exec(values...)
+ if err == nil {
+ return res.LastInsertId()
+ }
+ return 0, err
}
// query sql ,read records and persist in dbBaser.
@@ -339,15 +339,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
return ErrNoRows
}
return err
- } else {
- elm := reflect.New(mi.addrField.Elem().Type())
- mind := reflect.Indirect(elm)
-
- d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
-
- ind.Set(mind)
}
-
+ elm := reflect.New(mi.addrField.Elem().Type())
+ mind := reflect.Indirect(elm)
+ d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
+ ind.Set(mind)
return nil
}
@@ -444,20 +440,19 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
- if res, err := q.Exec(query, values...); err == nil {
+ res, err := q.Exec(query, values...)
+ if err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId()
- } else {
- return 0, err
}
- } else {
- row := q.QueryRow(query, values...)
- var id int64
- err := row.Scan(&id)
- return id, err
+ return 0, err
}
+ row := q.QueryRow(query, values...)
+ var id int64
+ err := row.Scan(&id)
+ return id, err
}
// execute update sql dbQuerier with given struct reflect.Value.
@@ -493,11 +488,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
d.ins.ReplaceMarks(&query)
- if res, err := q.Exec(query, setValues...); err == nil {
+ res, err := q.Exec(query, setValues...)
+ if err == nil {
return res.RowsAffected()
- } else {
- return 0, err
}
+ return 0, err
}
// execute delete sql dbQuerier with given struct reflect.Value.
@@ -513,14 +508,12 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q)
d.ins.ReplaceMarks(&query)
-
- if res, err := q.Exec(query, pkValue); err == nil {
-
+ res, err := q.Exec(query, pkValue)
+ if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
-
if num > 0 {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
@@ -529,17 +522,14 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
-
err := d.deleteRels(q, mi, []interface{}{pkValue}, tz)
if err != nil {
return num, err
}
}
-
return num, err
- } else {
- return 0, err
}
+ return 0, err
}
// update table-related record by querySet.
@@ -565,11 +555,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
tables.parseRelated(qs.related, qs.relDepth)
}
- where, args := tables.getCondSql(cond, false, tz)
+ where, args := tables.getCondSQL(cond, false, tz)
values = append(values, args...)
- join := tables.getJoinSql()
+ join := tables.getJoinSQL()
var query, T string
@@ -585,13 +575,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q)
if c, ok := values[i].(colValue); ok {
switch c.opt {
- case Col_Add:
+ case ColAdd:
cols = append(cols, col+" = "+col+" + ?")
- case Col_Minus:
+ case ColMinus:
cols = append(cols, col+" = "+col+" - ?")
- case Col_Multiply:
+ case ColMultiply:
cols = append(cols, col+" = "+col+" * ?")
- case Col_Except:
+ case ColExcept:
cols = append(cols, col+" = "+col+" / ?")
}
values[i] = c.value
@@ -610,12 +600,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
d.ins.ReplaceMarks(&query)
-
- if res, err := q.Exec(query, values...); err == nil {
+ res, err := q.Exec(query, values...)
+ if err == nil {
return res.RowsAffected()
- } else {
- return 0, err
}
+ return 0, err
}
// delete related records.
@@ -624,23 +613,23 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
switch fi.onDelete {
- case od_CASCADE:
+ case odCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
if err != nil {
return err
}
- case od_SET_DEFAULT, od_SET_NULL:
+ case odSetDefault, odSetNULL:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
params := Params{fi.column: nil}
- if fi.onDelete == od_SET_DEFAULT {
+ if fi.onDelete == odSetDefault {
params[fi.column] = fi.initial.String()
}
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
if err != nil {
return err
}
- case od_DO_NOTHING:
+ case odDoNothing:
}
}
return nil
@@ -661,8 +650,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
Q := d.ins.TableQuote()
- where, args := tables.getCondSql(cond, false, tz)
- join := tables.getJoinSql()
+ where, args := tables.getCondSQL(cond, false, tz)
+ join := tables.getJoinSQL()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where)
@@ -670,16 +659,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
- if r, err := q.Query(query, args...); err != nil {
+ r, err := q.Query(query, args...)
+ if err != nil {
return 0, err
- } else {
- rs = r
}
-
+ rs = r
defer rs.Close()
var ref interface{}
-
args = make([]interface{}, 0)
cnt := 0
for rs.Next() {
@@ -702,24 +689,21 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
d.ins.ReplaceMarks(&query)
-
- if res, err := q.Exec(query, args...); err == nil {
+ res, err := q.Exec(query, args...)
+ if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
-
if num > 0 {
err := d.deleteRels(q, mi, args, tz)
if err != nil {
return num, err
}
}
-
return num, nil
- } else {
- return 0, err
}
+ return 0, err
}
// read related records.
@@ -801,10 +785,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
- where, args := tables.getCondSql(cond, false, tz)
- orderBy := tables.getOrderSql(qs.orders)
- limit := tables.getLimitSql(mi, offset, rlimit)
- join := tables.getJoinSql()
+ where, args := tables.getCondSQL(cond, false, tz)
+ groupBy := tables.getGroupSQL(qs.groups)
+ orderBy := tables.getOrderSQL(qs.orders)
+ limit := tables.getLimitSQL(mi, offset, rlimit)
+ join := tables.getJoinSQL()
for _, tbl := range tables.tables {
if tbl.sel {
@@ -814,16 +799,20 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
}
}
- query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
+ sqlSelect := "SELECT"
+ if qs.distinct {
+ sqlSelect += " DISTINCT"
+ }
+ query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
- if r, err := q.Query(query, args...); err != nil {
+ r, err := q.Query(query, args...)
+ if err != nil {
return 0, err
- } else {
- rs = r
}
+ rs = r
refs := make([]interface{}, colsNum)
for i := range refs {
@@ -937,9 +926,9 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
- where, args := tables.getCondSql(cond, false, tz)
- tables.getOrderSql(qs.orders)
- join := tables.getJoinSql()
+ where, args := tables.getCondSQL(cond, false, tz)
+ tables.getOrderSQL(qs.orders)
+ join := tables.getJoinSQL()
Q := d.ins.TableQuote()
@@ -954,7 +943,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
}
// generate sql with replacing operator string placeholders and replaced values.
-func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
+func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := ""
params := getFlatParams(fi, args, tz)
@@ -979,7 +968,7 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
if len(params) > 1 {
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
}
- sql = d.ins.OperatorSql(operator)
+ sql = d.ins.OperatorSQL(operator)
switch operator {
case "exact":
if arg == nil {
@@ -1107,12 +1096,12 @@ setValue:
)
if len(s) >= 19 {
s = s[:19]
- t, err = time.ParseInLocation(format_DateTime, s, tz)
+ t, err = time.ParseInLocation(formatDateTime, s, tz)
} else {
if len(s) > 10 {
s = s[:10]
}
- t, err = time.ParseInLocation(format_Date, s, tz)
+ t, err = time.ParseInLocation(formatDate, s, tz)
}
t = t.In(DefaultTimeLoc)
@@ -1443,24 +1432,22 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
}
}
- where, args := tables.getCondSql(cond, false, tz)
- orderBy := tables.getOrderSql(qs.orders)
- limit := tables.getLimitSql(mi, qs.offset, qs.limit)
- join := tables.getJoinSql()
+ where, args := tables.getCondSQL(cond, false, tz)
+ groupBy := tables.getGroupSQL(qs.groups)
+ orderBy := tables.getOrderSQL(qs.orders)
+ limit := tables.getLimitSQL(mi, qs.offset, qs.limit)
+ join := tables.getJoinSQL()
sels := strings.Join(cols, ", ")
- query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
+ query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query)
- var rs *sql.Rows
- if r, err := q.Query(query, args...); err != nil {
+ rs, err := q.Query(query, args...)
+ if err != nil {
return 0, err
- } else {
- rs = r
}
-
refs := make([]interface{}, len(cols))
for i := range refs {
var ref interface{}
@@ -1475,11 +1462,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
)
for rs.Next() {
if cnt == 0 {
- if cols, err := rs.Columns(); err != nil {
+ cols, err := rs.Columns()
+ if err != nil {
return 0, err
- } else {
- columns = cols
}
+ columns = cols
}
if err := rs.Scan(refs...); err != nil {
@@ -1643,7 +1630,7 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
}
// not implement.
-func (d *dbBase) OperatorSql(operator string) string {
+func (d *dbBase) OperatorSQL(operator string) string {
panic(ErrNotImplement)
}
diff --git a/orm/db_alias.go b/orm/db_alias.go
index 0a862241..79576b8e 100644
--- a/orm/db_alias.go
+++ b/orm/db_alias.go
@@ -22,15 +22,17 @@ import (
"time"
)
-// database driver constant int.
+// DriverType database driver constant int.
type DriverType int
+// Enum the Database driver
const (
- _ DriverType = iota // int enum type
- DR_MySQL // mysql
- DR_Sqlite // sqlite
- DR_Oracle // oracle
- DR_Postgres // pgsql
+ _ DriverType = iota // int enum type
+ DRMySQL // mysql
+ DRSqlite // sqlite
+ DROracle // oracle
+ DRPostgres // pgsql
+ DRTiDB // TiDB
)
// database driver string.
@@ -53,15 +55,17 @@ var _ Driver = new(driver)
var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]DriverType{
- "mysql": DR_MySQL,
- "postgres": DR_Postgres,
- "sqlite3": DR_Sqlite,
+ "mysql": DRMySQL,
+ "postgres": DRPostgres,
+ "sqlite3": DRSqlite,
+ "tidb": DRTiDB,
}
dbBasers = map[DriverType]dbBaser{
- DR_MySQL: newdbBaseMysql(),
- DR_Sqlite: newdbBaseSqlite(),
- DR_Oracle: newdbBaseMysql(),
- DR_Postgres: newdbBasePostgres(),
+ DRMySQL: newdbBaseMysql(),
+ DRSqlite: newdbBaseSqlite(),
+ DROracle: newdbBaseOracle(),
+ DRPostgres: newdbBasePostgres(),
+ DRTiDB: newdbBaseTidb(),
}
)
@@ -119,7 +123,7 @@ func detectTZ(al *alias) {
}
switch al.Driver {
- case DR_MySQL:
+ case DRMySQL:
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string
row.Scan(&tz)
@@ -147,10 +151,10 @@ func detectTZ(al *alias) {
al.Engine = "INNODB"
}
- case DR_Sqlite:
+ case DRSqlite:
al.TZ = time.UTC
- case DR_Postgres:
+ case DRPostgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string
row.Scan(&tz)
@@ -188,12 +192,13 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
return al, nil
}
+// AddAliasWthDB add a aliasName for the drivename
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db)
return err
}
-// Setting the database connect params. Use the database driver self dataSource args.
+// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var (
err error
@@ -236,7 +241,7 @@ end:
return err
}
-// Register a database driver use specify driver name, this can be definition the driver is which database type.
+// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ
@@ -248,7 +253,7 @@ func RegisterDriver(driverName string, typ DriverType) error {
return nil
}
-// Change the database default used timezone
+// SetDataBaseTZ Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz
@@ -258,14 +263,14 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
return nil
}
-// Change the max idle conns for *sql.DB, use specify database alias name
+// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns
al.DB.SetMaxIdleConns(maxIdleConns)
}
-// Change the max open conns for *sql.DB, use specify database alias name
+// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
func SetMaxOpenConns(aliasName string, maxOpenConns int) {
al := getDbAlias(aliasName)
al.MaxOpenConns = maxOpenConns
@@ -275,7 +280,7 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
}
}
-// Get *sql.DB from registered database by db alias name.
+// GetDB Get *sql.DB from registered database by db alias name.
// Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string
@@ -284,9 +289,9 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} else {
name = "default"
}
- if al, ok := dataBaseCache.get(name); ok {
+ al, ok := dataBaseCache.get(name)
+ if ok {
return al.DB, nil
- } else {
- return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
}
+ return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
}
diff --git a/orm/db_mysql.go b/orm/db_mysql.go
index 182914a2..10fe2657 100644
--- a/orm/db_mysql.go
+++ b/orm/db_mysql.go
@@ -67,7 +67,7 @@ type dbBaseMysql struct {
var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
-func (d *dbBaseMysql) OperatorSql(operator string) string {
+func (d *dbBaseMysql) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}
diff --git a/orm/db_postgres.go b/orm/db_postgres.go
index 6500ef52..7dbef95a 100644
--- a/orm/db_postgres.go
+++ b/orm/db_postgres.go
@@ -66,7 +66,7 @@ type dbBasePostgres struct {
var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
-func (d *dbBasePostgres) OperatorSql(operator string) string {
+func (d *dbBasePostgres) OperatorSQL(operator string) string {
return postgresOperators[operator]
}
@@ -101,7 +101,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
num := 0
for _, c := range q {
if c == '?' {
- num += 1
+ num++
}
}
if num == 0 {
@@ -114,7 +114,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
if c == '?' {
data = append(data, '$')
data = append(data, []byte(strconv.Itoa(num))...)
- num += 1
+ num++
} else {
data = append(data, c)
}
diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go
index c2dcbb46..a3cb69a7 100644
--- a/orm/db_sqlite.go
+++ b/orm/db_sqlite.go
@@ -66,7 +66,7 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator.
-func (d *dbBaseSqlite) OperatorSql(operator string) string {
+func (d *dbBaseSqlite) OperatorSQL(operator string) string {
return sqliteOperators[operator]
}
diff --git a/orm/db_tables.go b/orm/db_tables.go
index a9aa10ab..e4c74ace 100644
--- a/orm/db_tables.go
+++ b/orm/db_tables.go
@@ -164,7 +164,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
}
// generate join string.
-func (t *dbTables) getJoinSql() (join string) {
+func (t *dbTables) getJoinSQL() (join string) {
Q := t.base.TableQuote()
for _, jt := range t.tables {
@@ -186,7 +186,7 @@ func (t *dbTables) getJoinSql() (join string) {
table = jt.mi.table
switch {
- case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
+ case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
@@ -220,7 +220,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
)
num := len(exprs) - 1
- names := make([]string, 0)
+ var names []string
inner := true
@@ -326,7 +326,7 @@ loopFor:
}
// generate condition sql.
-func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
+func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
@@ -347,7 +347,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
where += "NOT "
}
if p.isCond {
- w, ps := t.getCondSql(p.cond, true, tz)
+ w, ps := t.getCondSQL(p.cond, true, tz)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
@@ -372,12 +372,12 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
operator = "exact"
}
- operSql, args := t.base.GenerateOperatorSql(mi, fi, operator, p.args, tz)
+ operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
- where += fmt.Sprintf("%s %s ", leftCol, operSql)
+ where += fmt.Sprintf("%s %s ", leftCol, operSQL)
params = append(params, args...)
}
@@ -390,8 +390,32 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
return
}
+// generate group sql.
+func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
+ if len(groups) == 0 {
+ return
+ }
+
+ Q := t.base.TableQuote()
+
+ groupSqls := make([]string, 0, len(groups))
+ for _, group := range groups {
+ exprs := strings.Split(group, ExprSep)
+
+ index, _, fi, suc := t.parseExprs(t.mi, exprs)
+ if suc == false {
+ panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
+ }
+
+ groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
+ }
+
+ groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
+ return
+}
+
// generate order sql.
-func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
+func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
if len(orders) == 0 {
return
}
@@ -415,12 +439,12 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
}
- orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
+ orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return
}
// generate limit sql.
-func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
+func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 {
limit = int64(DefaultRowsLimit)
}
diff --git a/orm/db_tidb.go b/orm/db_tidb.go
new file mode 100644
index 00000000..6020a488
--- /dev/null
+++ b/orm/db_tidb.go
@@ -0,0 +1,63 @@
+// Copyright 2015 TiDB Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package orm
+
+import (
+ "fmt"
+)
+
+// mysql dbBaser implementation.
+type dbBaseTidb struct {
+ dbBase
+}
+
+var _ dbBaser = new(dbBaseTidb)
+
+// get mysql operator.
+func (d *dbBaseTidb) OperatorSQL(operator string) string {
+ return mysqlOperators[operator]
+}
+
+// get mysql table field types.
+func (d *dbBaseTidb) DbTypes() map[string]string {
+ return mysqlTypes
+}
+
+// show table sql for mysql.
+func (d *dbBaseTidb) ShowTablesQuery() string {
+ return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
+}
+
+// show columns sql of table for mysql.
+func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
+ return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
+ "WHERE table_schema = DATABASE() AND table_name = '%s'", table)
+}
+
+// execute sql to check index exist.
+func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
+ row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
+ var cnt int
+ row.Scan(&cnt)
+ return cnt > 0
+}
+
+// create new mysql dbBaser.
+func newdbBaseTidb() dbBaser {
+ b := new(dbBaseTidb)
+ b.ins = b
+ return b
+}
diff --git a/orm/db_utils.go b/orm/db_utils.go
index 4a3ba464..ae9b1625 100644
--- a/orm/db_utils.go
+++ b/orm/db_utils.go
@@ -24,9 +24,8 @@ import (
func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok {
return al
- } else {
- panic(fmt.Errorf("unknown DataBase alias name %s", name))
}
+ panic(fmt.Errorf("unknown DataBase alias name %s", name))
}
// get pk column info.
@@ -80,19 +79,19 @@ outFor:
var err error
if len(v) >= 19 {
s := v[:19]
- t, err = time.ParseInLocation(format_DateTime, s, DefaultTimeLoc)
+ t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
} else {
s := v
if len(v) > 10 {
s = v[:10]
}
- t, err = time.ParseInLocation(format_Date, s, tz)
+ t, err = time.ParseInLocation(formatDate, s, tz)
}
if err == nil {
if fi.fieldType == TypeDateField {
- v = t.In(tz).Format(format_Date)
+ v = t.In(tz).Format(formatDate)
} else {
- v = t.In(tz).Format(format_DateTime)
+ v = t.In(tz).Format(formatDateTime)
}
}
}
@@ -137,9 +136,9 @@ outFor:
case reflect.Struct:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
- arg = v.In(tz).Format(format_Date)
+ arg = v.In(tz).Format(formatDate)
} else {
- arg = v.In(tz).Format(format_DateTime)
+ arg = v.In(tz).Format(formatDateTime)
}
} else {
typ := val.Type()
diff --git a/orm/models.go b/orm/models.go
index dcb32b55..faf551be 100644
--- a/orm/models.go
+++ b/orm/models.go
@@ -19,10 +19,10 @@ import (
)
const (
- od_CASCADE = "cascade"
- od_SET_NULL = "set_null"
- od_SET_DEFAULT = "set_default"
- od_DO_NOTHING = "do_nothing"
+ odCascade = "cascade"
+ odSetNULL = "set_null"
+ odSetDefault = "set_default"
+ odDoNothing = "do_nothing"
defaultStructTagName = "orm"
defaultStructTagDelim = ";"
)
@@ -113,7 +113,7 @@ func (mc *_modelCache) clean() {
mc.done = false
}
-// Clean model cache. Then you can re-RegisterModel.
+// ResetModelCache Clean model cache. Then you can re-RegisterModel.
// Common use this api for test case.
func ResetModelCache() {
modelCache.clean()
diff --git a/orm/models_boot.go b/orm/models_boot.go
index cb44bc05..3690557b 100644
--- a/orm/models_boot.go
+++ b/orm/models_boot.go
@@ -51,19 +51,16 @@ func registerModel(prefix string, model interface{}) {
}
info := newModelInfo(val)
-
if info.fields.pk == nil {
outFor:
for _, fi := range info.fields.fieldsDB {
- if fi.name == "Id" {
- if fi.sf.Tag.Get(defaultStructTagName) == "" {
- switch fi.addrValue.Elem().Kind() {
- case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
- fi.auto = true
- fi.pk = true
- info.fields.pk = fi
- break outFor
- }
+ if strings.ToLower(fi.name) == "id" {
+ switch fi.addrValue.Elem().Kind() {
+ case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
+ fi.auto = true
+ fi.pk = true
+ info.fields.pk = fi
+ break outFor
}
}
}
@@ -269,7 +266,10 @@ func bootStrap() {
if found == false {
mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
- if ffi.relModelInfo == mi {
+ conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
+ fi.relTable != "" && fi.relTable == ffi.relTable ||
+ fi.relThrough == "" && fi.relTable == ""
+ if ffi.relModelInfo == mi && conditions {
found = true
fi.reverseField = ffi.reverseFieldInfoTwo.name
@@ -298,12 +298,12 @@ end:
}
}
-// register models
+// RegisterModel register models
func RegisterModel(models ...interface{}) {
RegisterModelWithPrefix("", models...)
}
-// register models with a prefix
+// RegisterModelWithPrefix register models with a prefix
func RegisterModelWithPrefix(prefix string, models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
@@ -314,7 +314,7 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
}
}
-// bootrap models.
+// BootStrap bootrap models.
// make all model parsed and can not add more models
func BootStrap() {
if modelCache.done {
diff --git a/orm/models_fields.go b/orm/models_fields.go
index f038dd0f..a8cf8e4f 100644
--- a/orm/models_fields.go
+++ b/orm/models_fields.go
@@ -15,49 +15,28 @@
package orm
import (
- "errors"
"fmt"
"strconv"
"time"
)
+// Define the Type enum
const (
- // bool
TypeBooleanField = 1 << iota
-
- // string
TypeCharField
-
- // string
TypeTextField
-
- // time.Time
TypeDateField
- // time.Time
TypeDateTimeField
-
- // int8
TypeBitField
- // int16
TypeSmallIntegerField
- // int32
TypeIntegerField
- // int64
TypeBigIntegerField
- // uint8
TypePositiveBitField
- // uint16
TypePositiveSmallIntegerField
- // uint32
TypePositiveIntegerField
- // uint64
TypePositiveBigIntegerField
-
- // float64
TypeFloatField
- // float64
TypeDecimalField
-
RelForeignKey
RelOneToOne
RelManyToMany
@@ -65,6 +44,7 @@ const (
RelReverseMany
)
+// Define some logic enum
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9
@@ -72,25 +52,30 @@ const (
IsFieldType = ^-RelReverseMany<<1 + 1
)
-// A true/false field.
+// BooleanField A true/false field.
type BooleanField bool
+// Value return the BooleanField
func (e BooleanField) Value() bool {
return bool(e)
}
+// Set will set the BooleanField
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
+// String format the Bool to string
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
+// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
+// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
@@ -102,56 +87,65 @@ func (e *BooleanField) SetRaw(value interface{}) error {
}
return err
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return the current value
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
+// verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField)
-// A string field
+// CharField A string field
// required values tag: size
// The size is enforced at the database level and in models’s validation.
// eg: `orm:"size(120)"`
type CharField string
+// Value return the CharField's Value
func (e CharField) Value() string {
return string(e)
}
+// Set CharField value
func (e *CharField) Set(d string) {
*e = CharField(d)
}
+// String return the CharField
func (e *CharField) String() string {
return e.Value()
}
+// FieldType return the enum type
func (e *CharField) FieldType() int {
return TypeCharField
}
+// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return the CharField value
func (e *CharField) RawValue() interface{} {
return e.Value()
}
+// verify CharField implement Fielder
var _ Fielder = new(CharField)
-// A date, represented in go by a time.Time instance.
+// DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02
// Has a few extra, optional attr tag:
//
@@ -166,106 +160,125 @@ var _ Fielder = new(CharField)
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time
+// Value return the time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
+// Set set the DateField's value
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
+// String convert datatime to string
func (e *DateField) String() string {
return e.Value().String()
}
+// FieldType return enum type Date
func (e *DateField) FieldType() int {
return TypeDateField
}
+// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
- v, err := timeParse(d, format_Date)
+ v, err := timeParse(d, formatDate)
if err != nil {
e.Set(v)
}
return err
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return Date value
func (e *DateField) RawValue() interface{} {
return e.Value()
}
+// verify DateField implement fielder interface
var _ Fielder = new(DateField)
-// A date, represented in go by a time.Time instance.
+// DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
+// Value return the datatime value
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
+// Set set the time.Time to datatime
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
+// String return the time's String
func (e *DateTimeField) String() string {
return e.Value().String()
}
+// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
+// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
- v, err := timeParse(d, format_DateTime)
+ v, err := timeParse(d, formatDateTime)
if err != nil {
e.Set(v)
}
return err
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return the datatime value
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
+// verify datatime implement fielder
var _ Fielder = new(DateTimeField)
-// A floating-point number represented in go by a float32 value.
+// FloatField A floating-point number represented in go by a float32 value.
type FloatField float64
+// Value return the FloatField value
func (e FloatField) Value() float64 {
return float64(e)
}
+// Set the Float64
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
+// String return the string
func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32)
}
+// FieldType return the enum type
func (e *FloatField) FieldType() int {
return TypeFloatField
}
+// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
@@ -278,36 +291,43 @@ func (e *FloatField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
+// verify FloatField implement Fielder
var _ Fielder = new(FloatField)
-// -32768 to 32767
+// SmallIntegerField -32768 to 32767
type SmallIntegerField int16
+// Value return int16 value
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
+// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
+// String convert smallint to string
func (e *SmallIntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
+// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
@@ -318,36 +338,43 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
+// verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField)
-// -2147483648 to 2147483647
+// IntegerField -2147483648 to 2147483647
type IntegerField int32
+// Value return the int32
func (e IntegerField) Value() int32 {
return int32(e)
}
+// Set IntegerField value
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
+// String convert Int32 to string
func (e *IntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return the enum type
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
+// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
@@ -358,36 +385,43 @@ func (e *IntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
+// verify IntegerField implement Fielder
var _ Fielder = new(IntegerField)
-// -9223372036854775808 to 9223372036854775807.
+// BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64
+// Value return int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
+// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
+// String convert BigIntegerField to string
func (e *BigIntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return enum type
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
+// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
@@ -398,36 +432,43 @@ func (e *BigIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
+// verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField)
-// 0 to 65535
+// PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16
+// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
+// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
+// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
+// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
@@ -438,36 +479,43 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
+// verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField)
-// 0 to 4294967295
+// PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32
+// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
+// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
+// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
+// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
@@ -478,36 +526,43 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
+// verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField)
-// 0 to 18446744073709551615
+// PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64
+// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
+// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
+// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
+// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
@@ -518,48 +573,57 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
+// verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField)
-// A large text field.
+// TextField A large text field.
type TextField string
+// Value return TextField value
func (e TextField) Value() string {
return string(e)
}
+// Set the TextField value
func (e *TextField) Set(d string) {
*e = TextField(d)
}
+// String convert TextField to string
func (e *TextField) String() string {
return e.Value()
}
+// FieldType return enum type
func (e *TextField) FieldType() int {
return TypeTextField
}
+// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return TextField value
func (e *TextField) RawValue() interface{} {
return e.Value()
}
+// verify TextField implement Fielder
var _ Fielder = new(TextField)
diff --git a/orm/models_info_f.go b/orm/models_info_f.go
index 84a0c024..14e1f2c6 100644
--- a/orm/models_info_f.go
+++ b/orm/models_info_f.go
@@ -119,8 +119,8 @@ type fieldInfo struct {
colDefault bool
initial StrTo
size int
- auto_now bool
- auto_now_add bool
+ autoNow bool
+ autoNowAdd bool
rel bool
reverse bool
reverseField string
@@ -223,6 +223,11 @@ checkType:
break checkType
case "many":
fieldType = RelReverseMany
+ if tv := tags["rel_table"]; tv != "" {
+ fi.relTable = tv
+ } else if tv := tags["rel_through"]; tv != "" {
+ fi.relThrough = tv
+ }
break checkType
default:
err = fmt.Errorf("error")
@@ -309,20 +314,20 @@ checkType:
if fi.rel && fi.dbcol {
switch onDelete {
- case od_CASCADE, od_DO_NOTHING:
- case od_SET_DEFAULT:
+ case odCascade, odDoNothing:
+ case odSetDefault:
if initial.Exist() == false {
err = errors.New("on_delete: set_default need set field a default value")
goto end
}
- case od_SET_NULL:
+ case odSetNULL:
if fi.null == false {
err = errors.New("on_delete: set_null need set field null")
goto end
}
default:
if onDelete == "" {
- onDelete = od_CASCADE
+ onDelete = odCascade
} else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end
@@ -350,9 +355,9 @@ checkType:
fi.unique = false
case TypeDateField, TypeDateTimeField:
if attrs["auto_now"] {
- fi.auto_now = true
+ fi.autoNow = true
} else if attrs["auto_now_add"] {
- fi.auto_now_add = true
+ fi.autoNowAdd = true
}
case TypeFloatField:
case TypeDecimalField:
diff --git a/orm/models_info_m.go b/orm/models_info_m.go
index 3600ee7c..2654cdb5 100644
--- a/orm/models_info_m.go
+++ b/orm/models_info_m.go
@@ -15,7 +15,6 @@
package orm
import (
- "errors"
"fmt"
"os"
"reflect"
@@ -72,13 +71,13 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
added := info.fields.Add(fi)
if added == false {
- err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
+ err = fmt.Errorf("duplicate column name: %s", fi.column)
break
}
if fi.pk {
if info.fields.pk != nil {
- err = errors.New(fmt.Sprintf("one model must have one pk field only"))
+ err = fmt.Errorf("one model must have one pk field only")
break
} else {
info.fields.pk = fi
diff --git a/orm/models_test.go b/orm/models_test.go
index 6ca9590c..ee56e8e8 100644
--- a/orm/models_test.go
+++ b/orm/models_test.go
@@ -25,6 +25,9 @@ import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
+
+ // As tidb can't use go get, so disable the tidb testing now
+ // _ "github.com/pingcap/tidb"
)
// A slice string field.
@@ -76,21 +79,21 @@ func (e *SliceStringField) RawValue() interface{} {
var _ Fielder = new(SliceStringField)
// A json field.
-type JsonField struct {
+type JSONField struct {
Name string
Data string
}
-func (e *JsonField) String() string {
+func (e *JSONField) String() string {
data, _ := json.Marshal(e)
return string(data)
}
-func (e *JsonField) FieldType() int {
+func (e *JSONField) FieldType() int {
return TypeTextField
}
-func (e *JsonField) SetRaw(value interface{}) error {
+func (e *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
return json.Unmarshal([]byte(d), e)
@@ -99,14 +102,14 @@ func (e *JsonField) SetRaw(value interface{}) error {
}
}
-func (e *JsonField) RawValue() interface{} {
+func (e *JSONField) RawValue() interface{} {
return e.String()
}
-var _ Fielder = new(JsonField)
+var _ Fielder = new(JSONField)
type Data struct {
- Id int
+ ID int `orm:"column(id)"`
Boolean bool
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
@@ -130,7 +133,7 @@ type Data struct {
}
type DataNull struct {
- Id int
+ ID int `orm:"column(id)"`
Boolean bool `orm:"null"`
Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"`
@@ -193,7 +196,7 @@ type Float32 float64
type Float64 float64
type DataCustom struct {
- Id int
+ ID int `orm:"column(id)"`
Boolean Boolean
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
@@ -216,28 +219,28 @@ type DataCustom struct {
// only for mysql
type UserBig struct {
- Id uint64
+ ID uint64 `orm:"column(id)"`
Name string
}
type User struct {
- Id int
- UserName string `orm:"size(30);unique"`
- Email string `orm:"size(100)"`
- Password string `orm:"size(100)"`
- Status int16 `orm:"column(Status)"`
- IsStaff bool
- IsActive bool `orm:"default(1)"`
- Created time.Time `orm:"auto_now_add;type(date)"`
- Updated time.Time `orm:"auto_now"`
- Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
- Posts []*Post `orm:"reverse(many)" json:"-"`
- ShouldSkip string `orm:"-"`
- Nums int
- Langs SliceStringField `orm:"size(100)"`
- Extra JsonField `orm:"type(text)"`
- unexport bool `orm:"-"`
- unexport_ bool
+ ID int `orm:"column(id)"`
+ UserName string `orm:"size(30);unique"`
+ Email string `orm:"size(100)"`
+ Password string `orm:"size(100)"`
+ Status int16 `orm:"column(Status)"`
+ IsStaff bool
+ IsActive bool `orm:"default(true)"`
+ Created time.Time `orm:"auto_now_add;type(date)"`
+ Updated time.Time `orm:"auto_now"`
+ Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
+ Posts []*Post `orm:"reverse(many)" json:"-"`
+ ShouldSkip string `orm:"-"`
+ Nums int
+ Langs SliceStringField `orm:"size(100)"`
+ Extra JSONField `orm:"type(text)"`
+ unexport bool `orm:"-"`
+ unexportBool bool
}
func (u *User) TableIndex() [][]string {
@@ -259,7 +262,7 @@ func NewUser() *User {
}
type Profile struct {
- Id int
+ ID int `orm:"column(id)"`
Age int16
Money float64
User *User `orm:"reverse(one)" json:"-"`
@@ -276,7 +279,7 @@ func NewProfile() *Profile {
}
type Post struct {
- Id int
+ ID int `orm:"column(id)"`
User *User `orm:"rel(fk)"`
Title string `orm:"size(60)"`
Content string `orm:"type(text)"`
@@ -297,7 +300,7 @@ func NewPost() *Post {
}
type Tag struct {
- Id int
+ ID int `orm:"column(id)"`
Name string `orm:"size(30)"`
BestPost *Post `orm:"rel(one);null"`
Posts []*Post `orm:"reverse(many)" json:"-"`
@@ -309,7 +312,7 @@ func NewTag() *Tag {
}
type PostTags struct {
- Id int
+ ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk)"`
Tag *Tag `orm:"rel(fk)"`
}
@@ -319,7 +322,7 @@ func (m *PostTags) TableName() string {
}
type Comment struct {
- Id int
+ ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk);column(post)"`
Content string `orm:"type(text)"`
Parent *Comment `orm:"null;rel(fk)"`
@@ -331,6 +334,24 @@ func NewComment() *Comment {
return obj
}
+type Group struct {
+ ID int `orm:"column(gid);size(32)"`
+ Name string
+ Permissions []*Permission `orm:"reverse(many)" json:"-"`
+}
+
+type Permission struct {
+ ID int `orm:"column(id)"`
+ Name string
+ Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"`
+}
+
+type GroupPermissions struct {
+ ID int `orm:"column(id)"`
+ Group *Group `orm:"rel(fk)"`
+ Permission *Permission `orm:"rel(fk)"`
+}
+
var DBARGS = struct {
Driver string
Source string
@@ -345,6 +366,7 @@ var (
IsMysql = DBARGS.Driver == "mysql"
IsSqlite = DBARGS.Driver == "sqlite3"
IsPostgres = DBARGS.Driver == "postgres"
+ IsTidb = DBARGS.Driver == "tidb"
)
var (
@@ -364,6 +386,7 @@ Default DB Drivers.
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq
+tidb: https://github.com/pingcap/tidb
usage:
@@ -371,6 +394,7 @@ go get -u github.com/astaxie/beego/orm
go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq
+go get -u github.com/pingcap/tidb
#### MySQL
mysql -u root -e 'create database orm_test;'
@@ -390,6 +414,12 @@ psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/astaxie/beego/orm
+
+#### TiDB
+export ORM_DRIVER=tidb
+export ORM_SOURCE='memory://test/test'
+go test -v github.com/astaxie/beego/orm
+
`)
os.Exit(2)
}
@@ -397,7 +427,7 @@ go test -v github.com/astaxie/beego/orm
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
alias := getDbAlias("default")
- if alias.Driver == DR_MySQL {
+ if alias.Driver == DRMySQL {
alias.Engine = "INNODB"
}
diff --git a/orm/orm.go b/orm/orm.go
index f881433b..d00d6d03 100644
--- a/orm/orm.go
+++ b/orm/orm.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage
//
// package main
@@ -59,12 +60,13 @@ import (
"time"
)
+// DebugQueries define the debug
const (
- Debug_Queries = iota
+ DebugQueries = iota
)
+// Define common vars
var (
- // DebugLevel = Debug_Queries
Debug = false
DebugLog = NewLog(os.Stderr)
DefaultRowsLimit = 1000
@@ -79,7 +81,10 @@ var (
ErrNotImplement = errors.New("have not implement")
)
+// Params stores the Params
type Params map[string]interface{}
+
+// ParamsList stores paramslist
type ParamsList []interface{}
type orm struct {
@@ -188,7 +193,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
o.setPk(mi, ind, id)
- cnt += 1
+ cnt++
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
@@ -489,7 +494,7 @@ func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
-// create new orm
+// NewOrm create new orm
func NewOrm() Ormer {
BootStrap() // execute only once
@@ -501,7 +506,7 @@ func NewOrm() Ormer {
return o
}
-// create a new ormer object with specify *sql.DB for query
+// NewOrmWithDB create a new ormer object with specify *sql.DB for query
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
var al *alias
diff --git a/orm/orm_conds.go b/orm/orm_conds.go
index 6344653a..b2eae418 100644
--- a/orm/orm_conds.go
+++ b/orm/orm_conds.go
@@ -19,6 +19,7 @@ import (
"strings"
)
+// ExprSep define the expression seperation
const (
ExprSep = "__"
)
@@ -32,19 +33,19 @@ type condValue struct {
isCond bool
}
-// condition struct.
+// Condition struct.
// work for WHERE conditions.
type Condition struct {
params []condValue
}
-// return new condition struct
+// NewCondition return new condition struct
func NewCondition() *Condition {
c := &Condition{}
return c
}
-// add expression to condition
+// And add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf(" args cannot empty"))
@@ -53,7 +54,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c
}
-// add NOT expression to condition
+// AndNot add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf(" args cannot empty"))
@@ -62,7 +63,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c
}
-// combine a condition to current condition
+// AndCond combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
@@ -74,7 +75,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c
}
-// add OR expression to condition
+// Or add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf(" args cannot empty"))
@@ -83,7 +84,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c
}
-// add OR NOT expression to condition
+// OrNot add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf(" args cannot empty"))
@@ -92,7 +93,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c
}
-// combine a OR condition to current condition
+// OrCond combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
@@ -104,12 +105,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c
}
-// check the condition arguments are empty or not.
+// IsEmpty check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool {
return len(c.params) == 0
}
-// clone a condition
+// clone clone a condition
func (c Condition) clone() *Condition {
return &c
}
diff --git a/orm/orm_log.go b/orm/orm_log.go
index 419d8e11..712eb219 100644
--- a/orm/orm_log.go
+++ b/orm/orm_log.go
@@ -23,11 +23,12 @@ import (
"time"
)
+// Log implement the log.Logger
type Log struct {
*log.Logger
}
-// set io.Writer to create a Logger.
+// NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log {
d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9)
@@ -41,7 +42,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if err != nil {
flag = "FAIL"
}
- con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query)
+ con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(formatDateTime), alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args))
for _, arg := range args {
cons = append(cons, fmt.Sprintf("%v", arg))
diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go
index 1eaccf72..60c77cdf 100644
--- a/orm/orm_querym2m.go
+++ b/orm/orm_querym2m.go
@@ -14,9 +14,7 @@
package orm
-import (
- "reflect"
-)
+import "reflect"
// model to model struct
type queryM2M struct {
@@ -44,7 +42,21 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
dbase := orm.alias.DbBaser
var models []interface{}
+ var other_values []interface{}
+ var other_names []string
+ for _, colname := range mi.fields.dbcols {
+ if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column &&
+ mi.fields.columns[colname] != mi.fields.pk {
+ other_names = append(other_names, colname)
+ }
+ }
+ for i, md := range mds {
+ if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 {
+ other_values = append(other_values, md)
+ mds = append(mds[:i], mds[i+1:]...)
+ }
+ }
for _, md := range mds {
val := reflect.ValueOf(md)
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
@@ -67,11 +79,9 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
names := []string{mfi.column, rfi.column}
values := make([]interface{}, 0, len(models)*2)
-
for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md))
-
var v2 interface{}
if ind.Kind() != reflect.Struct {
v2 = ind.Interface()
@@ -81,11 +91,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
panic(ErrMissPK)
}
}
-
values = append(values, v1, v2)
}
-
+ names = append(names, other_names...)
+ values = append(values, other_values...)
return dbase.InsertValue(orm.db, mi, true, names, values)
}
diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go
index 5cc47617..802a1fe0 100644
--- a/orm/orm_queryset.go
+++ b/orm/orm_queryset.go
@@ -25,11 +25,12 @@ type colValue struct {
type operator int
+// define Col operations
const (
- Col_Add operator = iota
- Col_Minus
- Col_Multiply
- Col_Except
+ ColAdd operator = iota
+ ColMinus
+ ColMultiply
+ ColExcept
)
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
@@ -38,7 +39,7 @@ const (
// }
func ColValue(opt operator, value interface{}) interface{} {
switch opt {
- case Col_Add, Col_Minus, Col_Multiply, Col_Except:
+ case ColAdd, ColMinus, ColMultiply, ColExcept:
default:
panic(fmt.Errorf("orm.ColValue wrong operator"))
}
@@ -60,7 +61,9 @@ type querySet struct {
relDepth int
limit int64
offset int64
+ groups []string
orders []string
+ distinct bool
orm *orm
}
@@ -105,6 +108,12 @@ func (o querySet) Offset(offset interface{}) QuerySeter {
return &o
}
+// add GROUP expression
+func (o querySet) GroupBy(exprs ...string) QuerySeter {
+ o.groups = exprs
+ return &o
+}
+
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter {
@@ -112,24 +121,30 @@ func (o querySet) OrderBy(exprs ...string) QuerySeter {
return &o
}
+// add DISTINCT to SELECT
+func (o querySet) Distinct() QuerySeter {
+ o.distinct = true
+ return &o
+}
+
// set relation model to query together.
// it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
- if len(params) == 0 {
- o.relDepth = DefaultRelsDepth
- } else {
- for _, p := range params {
- switch val := p.(type) {
- case string:
- o.related = append(o.related, val)
- case int:
- o.relDepth = val
- default:
- panic(fmt.Errorf(" wrong param kind: %v", val))
- }
- }
- }
- return &o
+ if len(params) == 0 {
+ o.relDepth = DefaultRelsDepth
+ } else {
+ for _, p := range params {
+ switch val := p.(type) {
+ case string:
+ o.related = append(o.related, val)
+ case int:
+ o.relDepth = val
+ default:
+ panic(fmt.Errorf(" wrong param kind: %v", val))
+ }
+ }
+ }
+ return &o
}
// set condition to QuerySeter.
diff --git a/orm/orm_raw.go b/orm/orm_raw.go
index 1452d6fc..cbb18064 100644
--- a/orm/orm_raw.go
+++ b/orm/orm_raw.go
@@ -165,14 +165,14 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
if str != "" {
if len(str) >= 19 {
str = str[:19]
- t, err := time.ParseInLocation(format_DateTime, str, o.orm.alias.TZ)
+ t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ)
if err == nil {
t = t.In(DefaultTimeLoc)
ind.Set(reflect.ValueOf(t))
}
} else if len(str) >= 10 {
str = str[:10]
- t, err := time.ParseInLocation(format_Date, str, DefaultTimeLoc)
+ t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc)
if err == nil {
ind.Set(reflect.ValueOf(t))
}
@@ -255,12 +255,13 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr
// query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error {
- refs := make([]interface{}, 0, len(containers))
- sInds := make([]reflect.Value, 0)
- eTyps := make([]reflect.Type, 0)
-
+ var (
+ refs = make([]interface{}, 0, len(containers))
+ sInds []reflect.Value
+ eTyps []reflect.Type
+ sMi *modelInfo
+ )
structMode := false
- var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@@ -385,12 +386,13 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
// query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
- refs := make([]interface{}, 0, len(containers))
- sInds := make([]reflect.Value, 0)
- eTyps := make([]reflect.Type, 0)
-
+ var (
+ refs = make([]interface{}, 0, len(containers))
+ sInds []reflect.Value
+ eTyps []reflect.Type
+ sMi *modelInfo
+ )
structMode := false
- var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
sInd := reflect.Indirect(val)
@@ -557,10 +559,9 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
- if r, err := o.orm.db.Query(query, args...); err != nil {
+ rs, err := o.orm.db.Query(query, args...)
+ if err != nil {
return 0, err
- } else {
- rs = r
}
defer rs.Close()
@@ -574,30 +575,30 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
for rs.Next() {
if cnt == 0 {
- if columns, err := rs.Columns(); err != nil {
+ columns, err := rs.Columns()
+ if err != nil {
return 0, err
+ }
+ if len(needCols) > 0 {
+ indexs = make([]int, 0, len(needCols))
} else {
+ indexs = make([]int, 0, len(columns))
+ }
+
+ cols = columns
+ refs = make([]interface{}, len(cols))
+ for i := range refs {
+ var ref sql.NullString
+ refs[i] = &ref
+
if len(needCols) > 0 {
- indexs = make([]int, 0, len(needCols))
- } else {
- indexs = make([]int, 0, len(columns))
- }
-
- cols = columns
- refs = make([]interface{}, len(cols))
- for i := range refs {
- var ref sql.NullString
- refs[i] = &ref
-
- if len(needCols) > 0 {
- for _, c := range needCols {
- if c == cols[i] {
- indexs = append(indexs, i)
- }
+ for _, c := range needCols {
+ if c == cols[i] {
+ indexs = append(indexs, i)
}
- } else {
- indexs = append(indexs, i)
}
+ } else {
+ indexs = append(indexs, i)
}
}
}
@@ -684,11 +685,9 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
- var rs *sql.Rows
- if r, err := o.orm.db.Query(query, args...); err != nil {
+ rs, err := o.orm.db.Query(query, args...)
+ if err != nil {
return 0, err
- } else {
- rs = r
}
defer rs.Close()
@@ -706,32 +705,29 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
for rs.Next() {
if cnt == 0 {
- if columns, err := rs.Columns(); err != nil {
+ columns, err := rs.Columns()
+ if err != nil {
return 0, err
- } else {
- cols = columns
- refs = make([]interface{}, len(cols))
- for i := range refs {
- if keyCol == cols[i] {
- keyIndex = i
- }
-
- if typ == 1 || keyIndex == i {
- var ref sql.NullString
- refs[i] = &ref
- } else {
- var ref interface{}
- refs[i] = &ref
- }
-
- if valueCol == cols[i] {
- valueIndex = i
- }
+ }
+ cols = columns
+ refs = make([]interface{}, len(cols))
+ for i := range refs {
+ if keyCol == cols[i] {
+ keyIndex = i
}
-
- if keyIndex == -1 || valueIndex == -1 {
- panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
+ if typ == 1 || keyIndex == i {
+ var ref sql.NullString
+ refs[i] = &ref
+ } else {
+ var ref interface{}
+ refs[i] = &ref
}
+ if valueCol == cols[i] {
+ valueIndex = i
+ }
+ }
+ if keyIndex == -1 || valueIndex == -1 {
+ panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
}
}
diff --git a/orm/orm_test.go b/orm/orm_test.go
index 14eadabd..d6f6c7a9 100644
--- a/orm/orm_test.go
+++ b/orm/orm_test.go
@@ -31,13 +31,26 @@ import (
var _ = os.PathSeparator
var (
- test_Date = format_Date + " -0700"
- test_DateTime = format_DateTime + " -0700"
+ testDate = formatDate + " -0700"
+ testDateTime = formatDateTime + " -0700"
)
-func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok bool) {
+type argAny []interface{}
+
+// get interface by index from interface slice
+func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
+ if i >= 0 && i < len(a) {
+ r = a[i]
+ }
+ if len(args) > 0 {
+ r = args[0]
+ }
+ return
+}
+
+func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) {
if len(args) == 0 {
- return fmt.Errorf("miss args"), false
+ return false, fmt.Errorf("miss args")
}
b := args[0]
arg := argAny(args)
@@ -71,21 +84,21 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok b
wrongArg:
if err != nil {
- return err, false
+ return false, err
}
- return nil, true
+ return true, nil
}
func AssertIs(a interface{}, args ...interface{}) error {
- if err, ok := ValuesCompare(true, a, args...); ok == false {
+ if ok, err := ValuesCompare(true, a, args...); ok == false {
return err
}
return nil
}
func AssertNot(a interface{}, args ...interface{}) error {
- if err, ok := ValuesCompare(false, a, args...); ok == false {
+ if ok, err := ValuesCompare(false, a, args...); ok == false {
return err
}
return nil
@@ -171,8 +184,11 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Comment))
RegisterModel(new(UserBig))
RegisterModel(new(PostTags))
+ RegisterModel(new(Group))
+ RegisterModel(new(Permission))
+ RegisterModel(new(GroupPermissions))
- err := RunSyncdb("default", true, false)
+ err := RunSyncdb("default", true, Debug)
throwFail(t, err)
modelCache.clean()
@@ -187,6 +203,9 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Comment))
RegisterModel(new(UserBig))
RegisterModel(new(PostTags))
+ RegisterModel(new(Group))
+ RegisterModel(new(Permission))
+ RegisterModel(new(GroupPermissions))
BootStrap()
@@ -208,7 +227,7 @@ func TestModelSyntax(t *testing.T) {
}
}
-var Data_Values = map[string]interface{}{
+var DataValues = map[string]interface{}{
"Boolean": true,
"Char": "char",
"Text": "text",
@@ -235,7 +254,7 @@ func TestDataTypes(t *testing.T) {
d := Data{}
ind := reflect.Indirect(reflect.ValueOf(&d))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value))
}
@@ -244,22 +263,22 @@ func TestDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
- d = Data{Id: 1}
+ d = Data{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@@ -278,7 +297,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
- d = DataNull{Id: 1}
+ d = DataNull{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
@@ -309,7 +328,7 @@ func TestNullDataTypes(t *testing.T) {
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err)
- d = DataNull{Id: 2}
+ d = DataNull{ID: 2}
err = dORM.Read(&d)
throwFail(t, err)
@@ -362,7 +381,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 3))
- d = DataNull{Id: 3}
+ d = DataNull{ID: 3}
err = dORM.Read(&d)
throwFail(t, err)
@@ -402,7 +421,7 @@ func TestDataCustomTypes(t *testing.T) {
d := DataCustom{}
ind := reflect.Indirect(reflect.ValueOf(&d))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
@@ -414,13 +433,13 @@ func TestDataCustomTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
- d = DataCustom{Id: 1}
+ d = DataCustom{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
@@ -451,7 +470,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
- u := &User{Id: user.Id}
+ u := &User{ID: user.ID}
err = dORM.Read(u)
throwFail(t, err)
@@ -461,8 +480,8 @@ func TestCRUD(t *testing.T) {
throwFail(t, AssertIs(u.Status, 3))
throwFail(t, AssertIs(u.IsStaff, true))
throwFail(t, AssertIs(u.IsActive, true))
- throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), test_Date))
- throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), test_DateTime))
+ throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate))
+ throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime))
user.UserName = "astaxie"
user.Profile = profile
@@ -470,11 +489,11 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
- u = &User{Id: user.Id}
+ u = &User{ID: user.ID}
err = dORM.Read(u)
throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "astaxie"))
- throwFail(t, AssertIs(u.Profile.Id, profile.Id))
+ throwFail(t, AssertIs(u.Profile.ID, profile.ID))
u = &User{UserName: "astaxie", Password: "pass"}
err = dORM.Read(u, "UserName")
@@ -487,7 +506,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
- u = &User{Id: user.Id}
+ u = &User{ID: user.ID}
err = dORM.Read(u)
throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "QQ"))
@@ -497,7 +516,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
- u = &User{Id: user.Id}
+ u = &User{ID: user.ID}
err = dORM.Read(u)
throwFail(t, err)
throwFail(t, AssertIs(true, u.Profile == nil))
@@ -506,7 +525,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
- u = &User{Id: 100}
+ u = &User{ID: 100}
err = dORM.Read(u)
throwFail(t, AssertIs(err, ErrNoRows))
@@ -516,7 +535,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
- ub = UserBig{Id: 1}
+ ub = UserBig{ID: 1}
err = dORM.Read(&ub)
throwFail(t, err)
throwFail(t, AssertIs(ub.Name, "name"))
@@ -586,7 +605,7 @@ func TestInsertTestData(t *testing.T) {
throwFail(t, AssertIs(id, 4))
tags := []*Tag{
- {Name: "golang", BestPost: &Post{Id: 2}},
+ {Name: "golang", BestPost: &Post{ID: 2}},
{Name: "example"},
{Name: "format"},
{Name: "c++"},
@@ -635,10 +654,47 @@ The program—and web server—godoc processes Go source files to extract docume
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
}
+
+ permissions := []*Permission{
+ {Name: "writePosts"},
+ {Name: "readComments"},
+ {Name: "readPosts"},
+ }
+
+ groups := []*Group{
+ {
+ Name: "admins",
+ Permissions: []*Permission{permissions[0], permissions[1], permissions[2]},
+ },
+ {
+ Name: "users",
+ Permissions: []*Permission{permissions[1], permissions[2]},
+ },
+ }
+
+ for _, permission := range permissions {
+ id, err := dORM.Insert(permission)
+ throwFail(t, err)
+ throwFail(t, AssertIs(id > 0, true))
+ }
+
+ for _, group := range groups {
+ _, err := dORM.Insert(group)
+ throwFail(t, err)
+ throwFail(t, AssertIs(id > 0, true))
+
+ num := len(group.Permissions)
+ if num > 0 {
+ nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions)
+ throwFailNow(t, err)
+ throwFailNow(t, AssertIs(nums, num))
+ }
+ }
+
}
func TestCustomField(t *testing.T) {
- user := User{Id: 2}
+ user := User{ID: 2}
err := dORM.Read(&user)
throwFailNow(t, err)
@@ -648,7 +704,7 @@ func TestCustomField(t *testing.T) {
_, err = dORM.Update(&user, "Langs", "Extra")
throwFailNow(t, err)
- user = User{Id: 2}
+ user = User{ID: 2}
err = dORM.Read(&user)
throwFailNow(t, err)
throwFailNow(t, AssertIs(len(user.Langs), 2))
@@ -702,7 +758,7 @@ func TestOperators(t *testing.T) {
var shouldNum int
- if IsSqlite {
+ if IsSqlite || IsTidb {
shouldNum = 2
} else {
shouldNum = 0
@@ -740,7 +796,7 @@ func TestOperators(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
- if IsSqlite {
+ if IsSqlite || IsTidb {
shouldNum = 1
} else {
shouldNum = 0
@@ -758,7 +814,7 @@ func TestOperators(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 2))
- if IsSqlite {
+ if IsSqlite || IsTidb {
shouldNum = 2
} else {
shouldNum = 0
@@ -889,9 +945,9 @@ func TestAll(t *testing.T) {
throwFailNow(t, AssertIs(users2[0].UserName, "slene"))
throwFailNow(t, AssertIs(users2[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(users2[2].UserName, "nobody"))
- throwFailNow(t, AssertIs(users2[0].Id, 0))
- throwFailNow(t, AssertIs(users2[1].Id, 0))
- throwFailNow(t, AssertIs(users2[2].Id, 0))
+ throwFailNow(t, AssertIs(users2[0].ID, 0))
+ throwFailNow(t, AssertIs(users2[1].ID, 0))
+ throwFailNow(t, AssertIs(users2[2].ID, 0))
throwFailNow(t, AssertIs(users2[0].Profile == nil, false))
throwFailNow(t, AssertIs(users2[1].Profile == nil, false))
throwFailNow(t, AssertIs(users2[2].Profile == nil, true))
@@ -986,6 +1042,10 @@ func TestValuesFlat(t *testing.T) {
}
func TestRelatedSel(t *testing.T) {
+ if IsTidb {
+ // Skip it. TiDB does not support relation now.
+ return
+ }
qs := dORM.QueryTable("user")
num, err := qs.Filter("profile__age", 28).Count()
throwFail(t, err)
@@ -1112,7 +1172,7 @@ func TestReverseQuery(t *testing.T) {
func TestLoadRelated(t *testing.T) {
// load reverse foreign key
- user := User{Id: 3}
+ user := User{ID: 3}
err := dORM.Read(&user)
throwFailNow(t, err)
@@ -1121,7 +1181,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2))
- throwFailNow(t, AssertIs(user.Posts[0].User.Id, 3))
+ throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3))
num, err = dORM.LoadRelated(&user, "Posts", true)
throwFailNow(t, err)
@@ -1143,8 +1203,8 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
// load reverse one to one
- profile := Profile{Id: 3}
- profile.BestPost = &Post{Id: 2}
+ profile := Profile{ID: 3}
+ profile.BestPost = &Post{ID: 2}
num, err = dORM.Update(&profile, "BestPost")
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
@@ -1183,7 +1243,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false))
throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples"))
- post := Post{Id: 2}
+ post := Post{ID: 2}
// load rel foreign key
err = dORM.Read(&post)
@@ -1204,7 +1264,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.User.Profile.Age, 30))
// load rel m2m
- post = Post{Id: 2}
+ post = Post{ID: 2}
err = dORM.Read(&post)
throwFailNow(t, err)
@@ -1224,7 +1284,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie"))
// load reverse m2m
- tag := Tag{Id: 1}
+ tag := Tag{ID: 1}
err = dORM.Read(&tag)
throwFailNow(t, err)
@@ -1233,19 +1293,19 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
- throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2))
+ throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true))
num, err = dORM.LoadRelated(&tag, "Posts", true)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
- throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2))
+ throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene"))
}
func TestQueryM2M(t *testing.T) {
- post := Post{Id: 4}
+ post := Post{ID: 4}
m2m := dORM.QueryM2M(&post, "Tags")
tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}}
@@ -1319,7 +1379,7 @@ func TestQueryM2M(t *testing.T) {
for _, post := range posts {
p := post.(*Post)
- p.User = &User{Id: 1}
+ p.User = &User{ID: 1}
_, err := dORM.Insert(post)
throwFailNow(t, err)
}
@@ -1394,6 +1454,18 @@ func TestQueryRelate(t *testing.T) {
// throwFailNow(t, AssertIs(num, 2))
}
+func TestPkManyRelated(t *testing.T) {
+ permission := &Permission{Name: "readPosts"}
+ err := dORM.Read(permission, "Name")
+ throwFailNow(t, err)
+
+ var groups []*Group
+ qs := dORM.QueryTable("Group")
+ num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups)
+ throwFailNow(t, err)
+ throwFailNow(t, AssertIs(num, 2))
+}
+
func TestPrepareInsert(t *testing.T) {
qs := dORM.QueryTable("user")
i, err := qs.PrepareInsert()
@@ -1459,10 +1531,10 @@ func TestRawQueryRow(t *testing.T) {
Decimal float64
)
- data_values := make(map[string]interface{}, len(Data_Values))
+ dataValues := make(map[string]interface{}, len(DataValues))
- for k, v := range Data_Values {
- data_values[strings.ToLower(k)] = v
+ for k, v := range DataValues {
+ dataValues[strings.ToLower(k)] = v
}
Q := dDbBaser.TableQuote()
@@ -1488,14 +1560,14 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(id, 1))
case "date":
v = v.(time.Time).In(DefaultTimeLoc)
- value := data_values[col].(time.Time).In(DefaultTimeLoc)
- throwFail(t, AssertIs(v, value, test_Date))
+ value := dataValues[col].(time.Time).In(DefaultTimeLoc)
+ throwFail(t, AssertIs(v, value, testDate))
case "datetime":
v = v.(time.Time).In(DefaultTimeLoc)
- value := data_values[col].(time.Time).In(DefaultTimeLoc)
- throwFail(t, AssertIs(v, value, test_DateTime))
+ value := dataValues[col].(time.Time).In(DefaultTimeLoc)
+ throwFail(t, AssertIs(v, value, testDateTime))
default:
- throwFail(t, AssertIs(v, data_values[col]))
+ throwFail(t, AssertIs(v, dataValues[col]))
}
}
@@ -1529,16 +1601,16 @@ func TestQueryRows(t *testing.T) {
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@@ -1553,16 +1625,16 @@ func TestQueryRows(t *testing.T) {
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@@ -1699,25 +1771,25 @@ func TestUpdate(t *testing.T) {
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
- "Nums": ColValue(Col_Add, 100),
+ "Nums": ColValue(ColAdd, 100),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
- "Nums": ColValue(Col_Minus, 50),
+ "Nums": ColValue(ColMinus, 50),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
- "Nums": ColValue(Col_Multiply, 3),
+ "Nums": ColValue(ColMultiply, 3),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
- "Nums": ColValue(Col_Except, 5),
+ "Nums": ColValue(ColExcept, 5),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
@@ -1838,15 +1910,15 @@ func TestReadOrCreate(t *testing.T) {
throwFail(t, AssertIs(u.Status, 7))
throwFail(t, AssertIs(u.IsStaff, false))
throwFail(t, AssertIs(u.IsActive, true))
- throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date))
- throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime))
+ throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate))
+ throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime))
nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"}
created, pk, err = dORM.ReadOrCreate(nu, "UserName")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
- throwFail(t, AssertIs(nu.Id, u.Id))
- throwFail(t, AssertIs(pk, u.Id))
+ throwFail(t, AssertIs(nu.ID, u.ID))
+ throwFail(t, AssertIs(pk, u.ID))
throwFail(t, AssertIs(nu.UserName, u.UserName))
throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above
throwFail(t, AssertIs(nu.Password, u.Password))
diff --git a/orm/qb.go b/orm/qb.go
index efe368db..9f778916 100644
--- a/orm/qb.go
+++ b/orm/qb.go
@@ -16,6 +16,7 @@ package orm
import "errors"
+// QueryBuilder is the Query builder interface
type QueryBuilder interface {
Select(fields ...string) QueryBuilder
From(tables ...string) QueryBuilder
@@ -43,15 +44,18 @@ type QueryBuilder interface {
String() string
}
+// NewQueryBuilder return the QueryBuilder
func NewQueryBuilder(driver string) (qb QueryBuilder, err error) {
if driver == "mysql" {
qb = new(MySQLQueryBuilder)
+ } else if driver == "tidb" {
+ qb = new(TiDBQueryBuilder)
} else if driver == "postgres" {
- err = errors.New("postgres query builder is not supported yet!")
+ err = errors.New("postgres query builder is not supported yet")
} else if driver == "sqlite" {
- err = errors.New("sqlite query builder is not supported yet!")
+ err = errors.New("sqlite query builder is not supported yet")
} else {
- err = errors.New("unknown driver for query builder!")
+ err = errors.New("unknown driver for query builder")
}
return
}
diff --git a/orm/qb_mysql.go b/orm/qb_mysql.go
index 9ce9b7d9..f6d1e185 100644
--- a/orm/qb_mysql.go
+++ b/orm/qb_mysql.go
@@ -20,134 +20,160 @@ import (
"strings"
)
-const COMMA_SPACE = ", "
+// CommaSpace is the seperation
+const CommaSpace = ", "
+// MySQLQueryBuilder is the SQL build
type MySQLQueryBuilder struct {
Tokens []string
}
+// Select will join the fields
func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb
}
+// From join the tables
func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb
}
+// InnerJoin INNER JOIN the table
func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb
}
+// LeftJoin LEFT JOIN the table
func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb
}
+// RightJoin RIGHT JOIN the table
func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb
}
+// On join with on cond
func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond)
return qb
}
+// Where join the Where cond
func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb
}
+// And join the and cond
func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond)
return qb
}
+// Or join the or cond
func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond)
return qb
}
+// In join the IN (vals)
func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, COMMA_SPACE), ")")
+ qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb
}
+// OrderBy join the Order by fields
func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb
}
+// Asc join the asc
func (qb *MySQLQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC")
return qb
}
+// Desc join the desc
func (qb *MySQLQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC")
return qb
}
+// Limit join the limit num
func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb
}
+// Offset join the offset num
func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb
}
+// GroupBy join the Group by fields
func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb
}
+// Having join the Having cond
func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb
}
+// Update join the update table
func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb
}
+// Set join the set kv
func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb
}
+// Delete join the Delete tables
func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 {
- qb.Tokens = append(qb.Tokens, strings.Join(tables, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
}
return qb
}
+// InsertInto join the insert SQL
func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 {
- fieldsStr := strings.Join(fields, COMMA_SPACE)
+ fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
}
return qb
}
+// Values join the Values(vals)
func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder {
- valsStr := strings.Join(vals, COMMA_SPACE)
+ valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb
}
+// Subquery join the sub as alias
func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias)
}
+// String join all Tokens
func (qb *MySQLQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ")
}
diff --git a/orm/qb_tidb.go b/orm/qb_tidb.go
new file mode 100644
index 00000000..c504049e
--- /dev/null
+++ b/orm/qb_tidb.go
@@ -0,0 +1,176 @@
+// Copyright 2015 TiDB Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package orm
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+// TiDBQueryBuilder is the SQL build
+type TiDBQueryBuilder struct {
+ Tokens []string
+}
+
+// Select will join the fields
+func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
+ return qb
+}
+
+// From join the tables
+func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
+ return qb
+}
+
+// InnerJoin INNER JOIN the table
+func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
+ return qb
+}
+
+// LeftJoin LEFT JOIN the table
+func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
+ return qb
+}
+
+// RightJoin RIGHT JOIN the table
+func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
+ return qb
+}
+
+// On join with on cond
+func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "ON", cond)
+ return qb
+}
+
+// Where join the Where cond
+func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "WHERE", cond)
+ return qb
+}
+
+// And join the and cond
+func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "AND", cond)
+ return qb
+}
+
+// Or join the or cond
+func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "OR", cond)
+ return qb
+}
+
+// In join the IN (vals)
+func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
+ return qb
+}
+
+// OrderBy join the Order by fields
+func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
+ return qb
+}
+
+// Asc join the asc
+func (qb *TiDBQueryBuilder) Asc() QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "ASC")
+ return qb
+}
+
+// Desc join the desc
+func (qb *TiDBQueryBuilder) Desc() QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "DESC")
+ return qb
+}
+
+// Limit join the limit num
+func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
+ return qb
+}
+
+// Offset join the offset num
+func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
+ return qb
+}
+
+// GroupBy join the Group by fields
+func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
+ return qb
+}
+
+// Having join the Having cond
+func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "HAVING", cond)
+ return qb
+}
+
+// Update join the update table
+func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
+ return qb
+}
+
+// Set join the set kv
+func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
+ return qb
+}
+
+// Delete join the Delete tables
+func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "DELETE")
+ if len(tables) != 0 {
+ qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
+ }
+ return qb
+}
+
+// InsertInto join the insert SQL
+func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
+ if len(fields) != 0 {
+ fieldsStr := strings.Join(fields, CommaSpace)
+ qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
+ }
+ return qb
+}
+
+// Values join the Values(vals)
+func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder {
+ valsStr := strings.Join(vals, CommaSpace)
+ qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
+ return qb
+}
+
+// Subquery join the sub as alias
+func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string {
+ return fmt.Sprintf("(%s) AS %s", sub, alias)
+}
+
+// String join all Tokens
+func (qb *TiDBQueryBuilder) String() string {
+ return strings.Join(qb.Tokens, " ")
+}
diff --git a/orm/types.go b/orm/types.go
index b46be4fc..5fac5fed 100644
--- a/orm/types.go
+++ b/orm/types.go
@@ -20,13 +20,13 @@ import (
"time"
)
-// database driver
+// Driver define database driver
type Driver interface {
Name() string
Type() DriverType
}
-// field info
+// Fielder define field info
type Fielder interface {
String() string
FieldType() int
@@ -34,84 +34,315 @@ type Fielder interface {
RawValue() interface{}
}
-// orm struct
+// Ormer define the orm interface
type Ormer interface {
- Read(interface{}, ...string) error
- ReadOrCreate(interface{}, string, ...string) (bool, int64, error)
+ // read data to model
+ // for example:
+ // this will find User by Id field
+ // u = &User{Id: user.Id}
+ // err = Ormer.Read(u)
+ // this will find User by UserName field
+ // u = &User{UserName: "astaxie", Password: "pass"}
+ // err = Ormer.Read(u, "UserName")
+ Read(md interface{}, cols ...string) error
+ // Try to read a row from the database, or insert one if it doesn't exist
+ ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error)
+ // insert model data to database
+ // for example:
+ // user := new(User)
+ // id, err = Ormer.Insert(user)
+ // user must a pointer and Insert will set user's pk field
Insert(interface{}) (int64, error)
- InsertMulti(int, interface{}) (int64, error)
- Update(interface{}, ...string) (int64, error)
- Delete(interface{}) (int64, error)
- LoadRelated(interface{}, string, ...interface{}) (int64, error)
- QueryM2M(interface{}, string) QueryM2Mer
- QueryTable(interface{}) QuerySeter
- Using(string) error
+ // insert some models to database
+ InsertMulti(bulk int, mds interface{}) (int64, error)
+ // update model to database.
+ // cols set the columns those want to update.
+ // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns
+ // for example:
+ // user := User{Id: 2}
+ // user.Langs = append(user.Langs, "zh-CN", "en-US")
+ // user.Extra.Name = "beego"
+ // user.Extra.Data = "orm"
+ // num, err = Ormer.Update(&user, "Langs", "Extra")
+ Update(md interface{}, cols ...string) (int64, error)
+ // delete model in database
+ Delete(md interface{}) (int64, error)
+ // load related models to md model.
+ // args are limit, offset int and order string.
+ //
+ // example:
+ // Ormer.LoadRelated(post,"Tags")
+ // for _,tag := range post.Tags{...}
+ //args[0] bool true useDefaultRelsDepth ; false depth 0
+ //args[0] int loadRelationDepth
+ //args[1] int limit default limit 1000
+ //args[2] int offset default offset 0
+ //args[3] string order for example : "-Id"
+ // make sure the relation is defined in model struct tags.
+ LoadRelated(md interface{}, name string, args ...interface{}) (int64, error)
+ // create a models to models queryer
+ // for example:
+ // post := Post{Id: 4}
+ // m2m := Ormer.QueryM2M(&post, "Tags")
+ QueryM2M(md interface{}, name string) QueryM2Mer
+ // return a QuerySeter for table operations.
+ // table name can be string or struct.
+ // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
+ QueryTable(ptrStructOrTableName interface{}) QuerySeter
+ // switch to another registered database driver by given name.
+ Using(name string) error
+ // begin transaction
+ // for example:
+ // o := NewOrm()
+ // err := o.Begin()
+ // ...
+ // err = o.Rollback()
Begin() error
+ // commit transaction
Commit() error
+ // rollback transaction
Rollback() error
- Raw(string, ...interface{}) RawSeter
+ // return a raw query seter for raw sql string.
+ // for example:
+ // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec()
+ // // update user testing's name to slene
+ Raw(query string, args ...interface{}) RawSeter
Driver() Driver
}
-// insert prepared statement
+// Inserter insert prepared statement
type Inserter interface {
Insert(interface{}) (int64, error)
Close() error
}
-// query seter
+// QuerySeter query seter
type QuerySeter interface {
+ // add condition expression to QuerySeter.
+ // for example:
+ // filter by UserName == 'slene'
+ // qs.Filter("UserName", "slene")
+ // sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28
+ // Filter("profile__Age", 28)
+ // // time compare
+ // qs.Filter("created", time.Now())
Filter(string, ...interface{}) QuerySeter
+ // add NOT condition to querySeter.
+ // have the same usage as Filter
Exclude(string, ...interface{}) QuerySeter
+ // set condition to QuerySeter.
+ // sql's where condition
+ // cond := orm.NewCondition()
+ // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
+ // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
+ // num, err := qs.SetCond(cond1).Count()
SetCond(*Condition) QuerySeter
- Limit(interface{}, ...interface{}) QuerySeter
- Offset(interface{}) QuerySeter
- OrderBy(...string) QuerySeter
- RelatedSel(...interface{}) QuerySeter
+ // add LIMIT value.
+ // args[0] means offset, e.g. LIMIT num,offset.
+ // if Limit <= 0 then Limit will be set to default limit ,eg 1000
+ // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000
+ // for example:
+ // qs.Limit(10, 2)
+ // // sql-> limit 10 offset 2
+ Limit(limit interface{}, args ...interface{}) QuerySeter
+ // add OFFSET value
+ // same as Limit function's args[0]
+ Offset(offset interface{}) QuerySeter
+ // add ORDER expression.
+ // "column" means ASC, "-column" means DESC.
+ // for example:
+ // qs.OrderBy("-status")
+ OrderBy(exprs ...string) QuerySeter
+ // set relation model to query together.
+ // it will query relation models and assign to parent model.
+ // for example:
+ // // will load all related fields use left join .
+ // qs.RelatedSel().One(&user)
+ // // will load related field only profile
+ // qs.RelatedSel("profile").One(&user)
+ // user.Profile.Age = 32
+ RelatedSel(params ...interface{}) QuerySeter
+ // return QuerySeter execution result number
+ // for example:
+ // num, err = qs.Filter("profile__age__gt", 28).Count()
Count() (int64, error)
+ // check result empty or not after QuerySeter executed
+ // the same as QuerySeter.Count > 0
Exist() bool
- Update(Params) (int64, error)
+ // execute update with parameters
+ // for example:
+ // num, err = qs.Filter("user_name", "slene").Update(Params{
+ // "Nums": ColValue(Col_Minus, 50),
+ // }) // user slene's Nums will minus 50
+ // num, err = qs.Filter("UserName", "slene").Update(Params{
+ // "user_name": "slene2"
+ // }) // user slene's name will change to slene2
+ Update(values Params) (int64, error)
+ // delete from table
+ //for example:
+ // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
+ // //delete two user who's name is testing1 or testing2
Delete() (int64, error)
+ // return a insert queryer.
+ // it can be used in times.
+ // example:
+ // i,err := sq.PrepareInsert()
+ // num, err = i.Insert(&user1) // user table will add one record user1 at once
+ // num, err = i.Insert(&user2) // user table will add one record user2 at once
+ // err = i.Close() //don't forget call Close
PrepareInsert() (Inserter, error)
- All(interface{}, ...string) (int64, error)
- One(interface{}, ...string) error
- Values(*[]Params, ...string) (int64, error)
- ValuesList(*[]ParamsList, ...string) (int64, error)
- ValuesFlat(*ParamsList, string) (int64, error)
- RowsToMap(*Params, string, string) (int64, error)
- RowsToStruct(interface{}, string, string) (int64, error)
+ // query all data and map to containers.
+ // cols means the columns when querying.
+ // for example:
+ // var users []*User
+ // qs.All(&users) // users[0],users[1],users[2] ...
+ All(container interface{}, cols ...string) (int64, error)
+ // query one row data and map to containers.
+ // cols means the columns when querying.
+ // for example:
+ // var user User
+ // qs.One(&user) //user.UserName == "slene"
+ One(container interface{}, cols ...string) error
+ // query all data and map to []map[string]interface.
+ // expres means condition expression.
+ // it converts data to []map[column]value.
+ // for example:
+ // var maps []Params
+ // qs.Values(&maps) //maps[0]["UserName"]=="slene"
+ Values(results *[]Params, exprs ...string) (int64, error)
+ // query all data and map to [][]interface
+ // it converts data to [][column_index]value
+ // for example:
+ // var list []ParamsList
+ // qs.ValuesList(&list) // list[0][1] == "slene"
+ ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
+ // query all data and map to []interface.
+ // it's designed for one column record set, auto change to []value, not [][column]value.
+ // for example:
+ // var list ParamsList
+ // qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
+ ValuesFlat(result *ParamsList, expr string) (int64, error)
+ // query all rows into map[string]interface with specify key and value column name.
+ // keyCol = "name", valueCol = "value"
+ // table data
+ // name | value
+ // total | 100
+ // found | 200
+ // to map[string]interface{}{
+ // "total": 100,
+ // "found": 200,
+ // }
+ RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
+ // query all rows into struct with specify key and value column name.
+ // keyCol = "name", valueCol = "value"
+ // table data
+ // name | value
+ // total | 100
+ // found | 200
+ // to struct {
+ // Total int
+ // Found int
+ // }
+ RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
}
-// model to model query struct
+// QueryM2Mer model to model query struct
+// all operations are on the m2m table only, will not affect the origin model table
type QueryM2Mer interface {
+ // add models to origin models when creating queryM2M.
+ // example:
+ // m2m := orm.QueryM2M(post,"Tag")
+ // m2m.Add(&Tag1{},&Tag2{})
+ // for _,tag := range post.Tags{}{ ... }
+ // param could also be any of the follow
+ // []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}}
+ // &Tag{Id:5,Name: "TestTag3"}
+ // []interface{}{&Tag{Id:6,Name: "TestTag4"}}
+ // insert one or more rows to m2m table
+ // make sure the relation is defined in post model struct tag.
Add(...interface{}) (int64, error)
+ // remove models following the origin model relationship
+ // only delete rows from m2m table
+ // for example:
+ //tag3 := &Tag{Id:5,Name: "TestTag3"}
+ //num, err = m2m.Remove(tag3)
Remove(...interface{}) (int64, error)
+ // check model is existed in relationship of origin model
Exist(interface{}) bool
+ // clean all models in related of origin model
Clear() (int64, error)
+ // count all related models of origin model
Count() (int64, error)
}
-// raw query statement
+// RawPreparer raw query statement
type RawPreparer interface {
Exec(...interface{}) (sql.Result, error)
Close() error
}
-// raw query seter
+// RawSeter raw query seter
+// create From Ormer.Raw
+// for example:
+// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q)
+// rs := Ormer.Raw(sql, 1)
type RawSeter interface {
+ //execute sql and get result
Exec() (sql.Result, error)
- QueryRow(...interface{}) error
- QueryRows(...interface{}) (int64, error)
+ //query data and map to container
+ //for example:
+ // var name string
+ // var id int
+ // rs.QueryRow(&id,&name) // id==2 name=="slene"
+ QueryRow(containers ...interface{}) error
+
+ // query data rows and map to container
+ // var ids []int
+ // var names []int
+ // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q)
+ // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"}
+ QueryRows(containers ...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
- Values(*[]Params, ...string) (int64, error)
- ValuesList(*[]ParamsList, ...string) (int64, error)
- ValuesFlat(*ParamsList, ...string) (int64, error)
- RowsToMap(*Params, string, string) (int64, error)
- RowsToStruct(interface{}, string, string) (int64, error)
+ // query data to []map[string]interface
+ // see QuerySeter's Values
+ Values(container *[]Params, cols ...string) (int64, error)
+ // query data to [][]interface
+ // see QuerySeter's ValuesList
+ ValuesList(container *[]ParamsList, cols ...string) (int64, error)
+ // query data to []interface
+ // see QuerySeter's ValuesFlat
+ ValuesFlat(container *ParamsList, cols ...string) (int64, error)
+ // query all rows into map[string]interface with specify key and value column name.
+ // keyCol = "name", valueCol = "value"
+ // table data
+ // name | value
+ // total | 100
+ // found | 200
+ // to map[string]interface{}{
+ // "total": 100,
+ // "found": 200,
+ // }
+ RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
+ // query all rows into struct with specify key and value column name.
+ // keyCol = "name", valueCol = "value"
+ // table data
+ // name | value
+ // total | 100
+ // found | 200
+ // to struct {
+ // Total int
+ // Found int
+ // }
+ RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
+
+ // return prepared raw statement for used in times.
+ // for example:
+ // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare()
+ // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`)
Prepare() (RawPreparer, error)
}
-// statement querier
+// stmtQuerier statement querier
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
@@ -160,8 +391,8 @@ type dbBaser interface {
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
- OperatorSql(string) string
- GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
+ OperatorSQL(string) string
+ GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
diff --git a/orm/utils.go b/orm/utils.go
index 88168763..99437c7b 100644
--- a/orm/utils.go
+++ b/orm/utils.go
@@ -22,9 +22,10 @@ import (
"time"
)
+// StrTo is the target string
type StrTo string
-// set string
+// Set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
@@ -33,93 +34,93 @@ func (f *StrTo) Set(v string) {
}
}
-// clean string
+// Clear string
func (f *StrTo) Clear() {
*f = StrTo(0x1E)
}
-// check string exist
+// Exist check string exist
func (f StrTo) Exist() bool {
return string(f) != string(0x1E)
}
-// string to bool
+// Bool string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
-// string to float32
+// Float32 string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
-// string to float64
+// Float64 string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
-// string to int
+// Int string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
-// string to int8
+// Int8 string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
-// string to int16
+// Int16 string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
-// string to int32
+// Int32 string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
-// string to int64
+// Int64 string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err
}
-// string to uint
+// Uint string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
-// string to uint8
+// Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
-// string to uint16
+// Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
-// string to uint31
+// Uint32 string to uint31
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
-// string to uint64
+// Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err
}
-// string to string
+// String string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
@@ -127,7 +128,7 @@ func (f StrTo) String() string {
return ""
}
-// interface to string
+// ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
@@ -166,7 +167,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s
}
-// interface to int64
+// ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {
@@ -248,30 +249,12 @@ func (a argInt) Get(i int, args ...int) (r int) {
return
}
-type argAny []interface{}
-
-// get interface by index from interface slice
-func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
- if i >= 0 && i < len(a) {
- r = a[i]
- }
- if len(args) > 0 {
- r = args[0]
- }
- return
-}
-
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
-// format time string
-func timeFormat(t time.Time, format string) string {
- return t.Format(format)
-}
-
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() {
diff --git a/parser.go b/parser.go
index cf8dee22..b14d74b9 100644
--- a/parser.go
+++ b/parser.go
@@ -42,13 +42,13 @@ func init() {
`
var (
- lastupdateFilename string = "lastupdate.tmp"
+ lastupdateFilename = "lastupdate.tmp"
commentFilename string
pkgLastupdate map[string]int64
genInfoList map[string][]ControllerComments
)
-const COMMENTFL = "commentsRouter_"
+const coomentPrefix = "commentsRouter_"
func init() {
pkgLastupdate = make(map[string]int64)
@@ -56,7 +56,7 @@ func init() {
func parserPkg(pkgRealpath, pkgpath string) error {
rep := strings.NewReplacer("/", "_", ".", "_")
- commentFilename = COMMENTFL + rep.Replace(pkgpath) + ".go"
+ commentFilename = coomentPrefix + rep.Replace(pkgpath) + ".go"
if !compareFile(pkgRealpath) {
Info(pkgRealpath + " has not changed, not reloading")
return nil
@@ -77,7 +77,10 @@ func parserPkg(pkgRealpath, pkgpath string) error {
switch specDecl := d.(type) {
case *ast.FuncDecl:
if specDecl.Recv != nil {
- parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(specDecl.Recv.List[0].Type.(*ast.StarExpr).X), pkgpath)
+ exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser
+ if ok {
+ parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath)
+ }
}
}
}
@@ -127,11 +130,13 @@ func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpat
}
func genRouterCode() {
- os.Mkdir(path.Join(workPath, "routers"), 0755)
+ os.Mkdir("routers", 0755)
Info("generate router from comments")
- var globalinfo string
- sortKey := make([]string, 0)
- for k, _ := range genInfoList {
+ var (
+ globalinfo string
+ sortKey []string
+ )
+ for k := range genInfoList {
sortKey = append(sortKey, k)
}
sort.Strings(sortKey)
@@ -167,7 +172,7 @@ func genRouterCode() {
}
}
if globalinfo != "" {
- f, err := os.Create(path.Join(workPath, "routers", commentFilename))
+ f, err := os.Create(path.Join("routers", commentFilename))
if err != nil {
panic(err)
}
@@ -177,11 +182,11 @@ func genRouterCode() {
}
func compareFile(pkgRealpath string) bool {
- if !utils.FileExists(path.Join(workPath, "routers", commentFilename)) {
+ if !utils.FileExists(path.Join("routers", commentFilename)) {
return true
}
- if utils.FileExists(path.Join(workPath, lastupdateFilename)) {
- content, err := ioutil.ReadFile(path.Join(workPath, lastupdateFilename))
+ if utils.FileExists(lastupdateFilename) {
+ content, err := ioutil.ReadFile(lastupdateFilename)
if err != nil {
return true
}
@@ -209,7 +214,7 @@ func savetoFile(pkgRealpath string) {
if err != nil {
return
}
- ioutil.WriteFile(path.Join(workPath, lastupdateFilename), d, os.ModePerm)
+ ioutil.WriteFile(lastupdateFilename, d, os.ModePerm)
}
func getpathTime(pkgRealpath string) (lastupdate int64, err error) {
diff --git a/plugins/apiauth/apiauth.go b/plugins/apiauth/apiauth.go
index bbae7def..3091c698 100644
--- a/plugins/apiauth/apiauth.go
+++ b/plugins/apiauth/apiauth.go
@@ -33,7 +33,7 @@
// // maybe store in configure, maybe in database
// }
//
-// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIAuthWithFunc(getAppSecret, 360))
+// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360))
//
// Infomation:
//
@@ -68,8 +68,10 @@ import (
"github.com/astaxie/beego/context"
)
-type AppIdToAppSecret func(string) string
+// AppIDToAppSecret is used to get appsecret throw appid
+type AppIDToAppSecret func(string) string
+// APIBaiscAuth use the basic appid/appkey as the AppIdToAppSecret
func APIBaiscAuth(appid, appkey string) beego.FilterFunc {
ft := func(aid string) string {
if aid == appid {
@@ -77,52 +79,54 @@ func APIBaiscAuth(appid, appkey string) beego.FilterFunc {
}
return ""
}
- return APIAuthWithFunc(ft, 300)
+ return APISecretAuth(ft, 300)
}
-func APIAuthWithFunc(f AppIdToAppSecret, timeout int) beego.FilterFunc {
+// APISecretAuth use AppIdToAppSecret verify and
+func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc {
return func(ctx *context.Context) {
if ctx.Input.Query("appid") == "" {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("miss query param: appid")
return
}
appsecret := f(ctx.Input.Query("appid"))
if appsecret == "" {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("not exist this appid")
return
}
if ctx.Input.Query("signature") == "" {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("miss query param: signature")
return
}
if ctx.Input.Query("timestamp") == "" {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("miss query param: timestamp")
return
}
u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp"))
if err != nil {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05")
return
}
t := time.Now()
if t.Sub(u).Seconds() > float64(timeout) {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("timeout! the request time is long ago, please try again")
return
}
if ctx.Input.Query("signature") !=
- Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.Uri()) {
- ctx.Output.SetStatus(403)
+ Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URI()) {
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("auth failed")
}
}
}
+// Signature used to generate signature with the appsecret/method/params/RequestURI
func Signature(appsecret, method string, params url.Values, RequestURI string) (result string) {
var query string
pa := make(map[string]string)
@@ -139,11 +143,11 @@ func Signature(appsecret, method string, params url.Values, RequestURI string) (
query = fmt.Sprintf("%v%v%v", query, vs.Keys[i], vs.Vals[i])
}
}
- string_to_sign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURI)
+ stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURI)
sha256 := sha256.New
hash := hmac.New(sha256, []byte(appsecret))
- hash.Write([]byte(string_to_sign))
+ hash.Write([]byte(stringToSign))
return base64.StdEncoding.EncodeToString(hash.Sum(nil))
}
diff --git a/plugins/auth/basic.go b/plugins/auth/basic.go
index 946b8457..c478044a 100644
--- a/plugins/auth/basic.go
+++ b/plugins/auth/basic.go
@@ -46,6 +46,7 @@ import (
var defaultRealm = "Authorization Required"
+// Basic is the http basic auth
func Basic(username string, password string) beego.FilterFunc {
secrets := func(user, pass string) bool {
return user == username && pass == password
@@ -53,6 +54,7 @@ func Basic(username string, password string) beego.FilterFunc {
return NewBasicAuthenticator(secrets, defaultRealm)
}
+// NewBasicAuthenticator return the BasicAuth
func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc {
return func(ctx *context.Context) {
a := &BasicAuth{Secrets: secrets, Realm: Realm}
@@ -62,17 +64,19 @@ func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFun
}
}
+// SecretProvider is the SecretProvider function
type SecretProvider func(user, pass string) bool
+// BasicAuth store the SecretProvider and Realm
type BasicAuth struct {
Secrets SecretProvider
Realm string
}
-//Checks the username/password combination from the request. Returns
-//either an empty string (authentication failed) or the name of the
-//authenticated user.
-//Supports MD5 and SHA1 password entries
+// CheckAuth Checks the username/password combination from the request. Returns
+// either an empty string (authentication failed) or the name of the
+// authenticated user.
+// Supports MD5 and SHA1 password entries
func (a *BasicAuth) CheckAuth(r *http.Request) string {
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
if len(s) != 2 || s[0] != "Basic" {
@@ -94,8 +98,8 @@ func (a *BasicAuth) CheckAuth(r *http.Request) string {
return ""
}
-//http.Handler for BasicAuth which initiates the authentication process
-//(or requires reauthentication).
+// RequireAuth http.Handler for BasicAuth which initiates the authentication process
+// (or requires reauthentication).
func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`)
w.WriteHeader(401)
diff --git a/plugins/cors/cors.go b/plugins/cors/cors.go
index 052d3bc6..1e973a40 100644
--- a/plugins/cors/cors.go
+++ b/plugins/cors/cors.go
@@ -24,7 +24,7 @@
// // - PUT and PATCH methods
// // - Origin header
// // - Credentials share
-// beego.InsertFilter("*", beego.BeforeRouter,cors.Allow(&cors.Options{
+// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{
// AllowOrigins: []string{"https://*.foo.com"},
// AllowMethods: []string{"PUT", "PATCH"},
// AllowHeaders: []string{"Origin"},
@@ -36,7 +36,6 @@
package cors
import (
- "net/http"
"regexp"
"strconv"
"strings"
@@ -216,8 +215,6 @@ func Allow(opts *Options) beego.FilterFunc {
for key, value := range headers {
ctx.Output.Header(key, value)
}
- ctx.Output.SetStatus(http.StatusOK)
- ctx.WriteString("")
return
}
headers = opts.Header(origin)
diff --git a/plugins/cors/cors_test.go b/plugins/cors/cors_test.go
index 5c02ab98..34039143 100644
--- a/plugins/cors/cors_test.go
+++ b/plugins/cors/cors_test.go
@@ -25,21 +25,23 @@ import (
"github.com/astaxie/beego/context"
)
-type HttpHeaderGuardRecorder struct {
+// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header
+type HTTPHeaderGuardRecorder struct {
*httptest.ResponseRecorder
savedHeaderMap http.Header
}
-func NewRecorder() *HttpHeaderGuardRecorder {
- return &HttpHeaderGuardRecorder{httptest.NewRecorder(), nil}
+// NewRecorder return HttpHeaderGuardRecorder
+func NewRecorder() *HTTPHeaderGuardRecorder {
+ return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil}
}
-func (gr *HttpHeaderGuardRecorder) WriteHeader(code int) {
+func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) {
gr.ResponseRecorder.WriteHeader(code)
gr.savedHeaderMap = gr.ResponseRecorder.Header()
}
-func (gr *HttpHeaderGuardRecorder) Header() http.Header {
+func (gr *HTTPHeaderGuardRecorder) Header() http.Header {
if gr.savedHeaderMap != nil {
// headers were written. clone so we don't get updates
clone := make(http.Header)
@@ -47,9 +49,8 @@ func (gr *HttpHeaderGuardRecorder) Header() http.Header {
clone[k] = v
}
return clone
- } else {
- return gr.ResponseRecorder.Header()
}
+ return gr.ResponseRecorder.Header()
}
func Test_AllowAll(t *testing.T) {
@@ -219,13 +220,13 @@ func Test_Preflight(t *testing.T) {
func Benchmark_WithoutCORS(b *testing.B) {
recorder := httptest.NewRecorder()
handler := beego.NewControllerRegister()
- beego.RunMode = "prod"
+ beego.BConfig.RunMode = beego.PROD
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
b.ResetTimer()
- for i := 0; i < 100; i++ {
- r, _ := http.NewRequest("PUT", "/foo", nil)
+ r, _ := http.NewRequest("PUT", "/foo", nil)
+ for i := 0; i < b.N; i++ {
handler.ServeHTTP(recorder, r)
}
}
@@ -233,7 +234,7 @@ func Benchmark_WithoutCORS(b *testing.B) {
func Benchmark_WithCORS(b *testing.B) {
recorder := httptest.NewRecorder()
handler := beego.NewControllerRegister()
- beego.RunMode = "prod"
+ beego.BConfig.RunMode = beego.PROD
handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
AllowAllOrigins: true,
AllowCredentials: true,
@@ -245,8 +246,8 @@ func Benchmark_WithCORS(b *testing.B) {
ctx.Output.SetStatus(500)
})
b.ResetTimer()
- for i := 0; i < 100; i++ {
- r, _ := http.NewRequest("PUT", "/foo", nil)
+ r, _ := http.NewRequest("PUT", "/foo", nil)
+ for i := 0; i < b.N; i++ {
handler.ServeHTTP(recorder, r)
}
}
diff --git a/router.go b/router.go
index 3e1ebab3..9d8cd2e6 100644
--- a/router.go
+++ b/router.go
@@ -15,10 +15,7 @@
package beego
import (
- "bufio"
- "errors"
"fmt"
- "net"
"net/http"
"os"
"path"
@@ -27,6 +24,7 @@ import (
"runtime"
"strconv"
"strings"
+ "sync"
"time"
beecontext "github.com/astaxie/beego/context"
@@ -34,8 +32,8 @@ import (
"github.com/astaxie/beego/utils"
)
+// default filter execution points
const (
- // default filter execution points
BeforeStatic = iota
BeforeRouter
BeforeExec
@@ -50,7 +48,7 @@ const (
)
var (
- // supported http methods.
+ // HTTPMETHOD list the supported http methods.
HTTPMETHOD = map[string]string{
"GET": "GET",
"POST": "POST",
@@ -71,10 +69,12 @@ var (
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
"GetControllerAndAction"}
- url_placeholder = "{{placeholder}}"
- DefaultLogFilter FilterHandler = &logFilter{}
+ urlPlaceholder = "{{placeholder}}"
+ // DefaultAccessLogFilter will skip the accesslog if return true
+ DefaultAccessLogFilter FilterHandler = &logFilter{}
)
+// FilterHandler is an interface for
type FilterHandler interface {
Filter(*beecontext.Context) bool
}
@@ -84,11 +84,11 @@ type logFilter struct {
}
func (l *logFilter) Filter(ctx *beecontext.Context) bool {
- requestPath := path.Clean(ctx.Input.Request.URL.Path)
+ requestPath := path.Clean(ctx.Request.URL.Path)
if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
return true
}
- for prefix := range StaticDir {
+ for prefix := range BConfig.WebConfig.StaticDir {
if strings.HasPrefix(requestPath, prefix) {
return true
}
@@ -96,7 +96,7 @@ func (l *logFilter) Filter(ctx *beecontext.Context) bool {
return false
}
-// To append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter
+// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter
func ExceptMethodAppend(action string) {
exceptMethod = append(exceptMethod, action)
}
@@ -106,26 +106,31 @@ type controllerInfo struct {
controllerType reflect.Type
methods map[string]string
handler http.Handler
- runfunction FilterFunc
+ runFunction FilterFunc
routerType int
}
-// ControllerRegistor containers registered router rules, controller handlers and filters.
-type ControllerRegistor struct {
+// ControllerRegister containers registered router rules, controller handlers and filters.
+type ControllerRegister struct {
routers map[string]*Tree
enableFilter bool
filters map[int][]*FilterRouter
+ pool sync.Pool
}
-// NewControllerRegister returns a new ControllerRegistor.
-func NewControllerRegister() *ControllerRegistor {
- return &ControllerRegistor{
+// NewControllerRegister returns a new ControllerRegister.
+func NewControllerRegister() *ControllerRegister {
+ cr := &ControllerRegister{
routers: make(map[string]*Tree),
filters: make(map[int][]*FilterRouter),
}
+ cr.pool.New = func() interface{} {
+ return beecontext.NewContext()
+ }
+ return cr
}
-// Add controller handler and pattern rules to ControllerRegistor.
+// Add controller handler and pattern rules to ControllerRegister.
// usage:
// default methods is the same name as method
// Add("/user",&UserController{})
@@ -133,9 +138,9 @@ func NewControllerRegister() *ControllerRegistor {
// Add("/api/create",&RestController{},"post:CreateFood")
// Add("/api/update",&RestController{},"put:UpdateFood")
// Add("/api/delete",&RestController{},"delete:DeleteFood")
-// Add("/api",&RestController{},"get,post:ApiFunc")
+// Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
-func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
+func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
methods := make(map[string]string)
@@ -183,8 +188,8 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
}
}
-func (p *ControllerRegistor) addToRouter(method, pattern string, r *controllerInfo) {
- if !RouterCaseSensitive {
+func (p *ControllerRegister) addToRouter(method, pattern string, r *controllerInfo) {
+ if !BConfig.RouterCaseSensitive {
pattern = strings.ToLower(pattern)
}
if t, ok := p.routers[method]; ok {
@@ -196,10 +201,10 @@ func (p *ControllerRegistor) addToRouter(method, pattern string, r *controllerIn
}
}
-// only when the Runmode is dev will generate router file in the router/auto.go from the controller
+// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller
// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
-func (p *ControllerRegistor) Include(cList ...ControllerInterface) {
- if RunMode == "dev" {
+func (p *ControllerRegister) Include(cList ...ControllerInterface) {
+ if BConfig.RunMode == DEV {
skip := make(map[string]bool, 10)
for _, c := range cList {
reflectVal := reflect.ValueOf(c)
@@ -238,91 +243,91 @@ func (p *ControllerRegistor) Include(cList ...ControllerInterface) {
}
}
-// add get method
+// Get add get method
// usage:
// Get("/", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Get(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Get(pattern string, f FilterFunc) {
p.AddMethod("get", pattern, f)
}
-// add post method
+// Post add post method
// usage:
// Post("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Post(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Post(pattern string, f FilterFunc) {
p.AddMethod("post", pattern, f)
}
-// add put method
+// Put add put method
// usage:
// Put("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Put(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Put(pattern string, f FilterFunc) {
p.AddMethod("put", pattern, f)
}
-// add delete method
+// Delete add delete method
// usage:
// Delete("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Delete(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Delete(pattern string, f FilterFunc) {
p.AddMethod("delete", pattern, f)
}
-// add head method
+// Head add head method
// usage:
// Head("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Head(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Head(pattern string, f FilterFunc) {
p.AddMethod("head", pattern, f)
}
-// add patch method
+// Patch add patch method
// usage:
// Patch("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Patch(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Patch(pattern string, f FilterFunc) {
p.AddMethod("patch", pattern, f)
}
-// add options method
+// Options add options method
// usage:
// Options("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Options(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Options(pattern string, f FilterFunc) {
p.AddMethod("options", pattern, f)
}
-// add all method
+// Any add all method
// usage:
// Any("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Any(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Any(pattern string, f FilterFunc) {
p.AddMethod("*", pattern, f)
}
-// add http method router
+// AddMethod add http method router
// usage:
// AddMethod("get","/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) AddMethod(method, pattern string, f FilterFunc) {
+func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
if _, ok := HTTPMETHOD[strings.ToUpper(method)]; method != "*" && !ok {
panic("not support http method: " + method)
}
route := &controllerInfo{}
route.pattern = pattern
route.routerType = routerTypeRESTFul
- route.runfunction = f
+ route.runFunction = f
methods := make(map[string]string)
if method == "*" {
for _, val := range HTTPMETHOD {
@@ -343,15 +348,15 @@ func (p *ControllerRegistor) AddMethod(method, pattern string, f FilterFunc) {
}
}
-// add user defined Handler
-func (p *ControllerRegistor) Handler(pattern string, h http.Handler, options ...interface{}) {
+// Handler add user defined Handler
+func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
route := &controllerInfo{}
route.pattern = pattern
route.routerType = routerTypeHandler
route.handler = h
if len(options) > 0 {
if _, ok := options[0].(bool); ok {
- pattern = path.Join(pattern, "?:all")
+ pattern = path.Join(pattern, "?:all(.*)")
}
}
for _, m := range HTTPMETHOD {
@@ -359,21 +364,21 @@ func (p *ControllerRegistor) Handler(pattern string, h http.Handler, options ...
}
}
-// Add auto router to ControllerRegistor.
+// AddAuto router to ControllerRegister.
// example beego.AddAuto(&MainContorlller{}),
// MainController has method List and Page.
// visit the url /main/list to execute List function
// /main/page to execute Page function.
-func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
+func (p *ControllerRegister) AddAuto(c ControllerInterface) {
p.AddAutoPrefix("/", c)
}
-// Add auto router to ControllerRegistor with prefix.
+// AddAutoPrefix Add auto router to ControllerRegister with prefix.
// example beego.AddAutoPrefix("/admin",&MainContorlller{}),
// MainController has method List and Page.
// visit the url /admin/main/list to execute List function
// /admin/main/page to execute Page function.
-func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) {
+func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) {
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
@@ -386,28 +391,28 @@ func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface)
route.controllerType = ct
pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*")
patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*")
- patternfix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
- patternfixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
+ patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
+ patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
route.pattern = pattern
for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
p.addToRouter(m, patternInit, route)
- p.addToRouter(m, patternfix, route)
- p.addToRouter(m, patternfixInit, route)
+ p.addToRouter(m, patternFix, route)
+ p.addToRouter(m, patternFixInit, route)
}
}
}
}
-// Add a FilterFunc with pattern rule and action constant.
+// InsertFilter Add a FilterFunc with pattern rule and action constant.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
-func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
+func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
mr := new(FilterRouter)
mr.tree = NewTree()
mr.pattern = pattern
mr.filterFunc = filter
- if !RouterCaseSensitive {
+ if !BConfig.RouterCaseSensitive {
pattern = strings.ToLower(pattern)
}
if len(params) == 0 {
@@ -420,15 +425,15 @@ func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter Filter
}
// add Filter into
-func (p *ControllerRegistor) insertFilterRouter(pos int, mr *FilterRouter) error {
+func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error {
p.filters[pos] = append(p.filters[pos], mr)
p.enableFilter = true
return nil
}
-// UrlFor does another controller handler in this request function.
+// URLFor does another controller handler in this request function.
// it can access any controller method.
-func (p *ControllerRegistor) UrlFor(endpoint string, values ...interface{}) string {
+func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string {
paths := strings.Split(endpoint, ".")
if len(paths) <= 1 {
Warn("urlfor endpoint must like path.controller.method")
@@ -460,16 +465,16 @@ func (p *ControllerRegistor) UrlFor(endpoint string, values ...interface{}) stri
return ""
}
-func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) {
- for k, subtree := range t.fixrouters {
- u := path.Join(url, k)
+func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) {
+ for _, subtree := range t.fixrouters {
+ u := path.Join(url, subtree.prefix)
ok, u := p.geturl(subtree, u, controllName, methodName, params, httpMethod)
if ok {
return ok, u
}
}
if t.wildcard != nil {
- u := path.Join(url, url_placeholder)
+ u := path.Join(url, urlPlaceholder)
ok, u := p.geturl(t.wildcard, u, controllName, methodName, params, httpMethod)
if ok {
return ok, u
@@ -499,22 +504,21 @@ func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName strin
if find {
if l.regexps == nil {
if len(l.wildcards) == 0 {
- return true, strings.Replace(url, "/"+url_placeholder, "", 1) + tourl(params)
+ return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toUrl(params)
}
if len(l.wildcards) == 1 {
if v, ok := params[l.wildcards[0]]; ok {
delete(params, l.wildcards[0])
- return true, strings.Replace(url, url_placeholder, v, 1) + tourl(params)
- } else {
- return false, ""
+ return true, strings.Replace(url, urlPlaceholder, v, 1) + toUrl(params)
}
+ return false, ""
}
if len(l.wildcards) == 3 && l.wildcards[0] == "." {
if p, ok := params[":path"]; ok {
if e, isok := params[":ext"]; isok {
delete(params, ":path")
delete(params, ":ext")
- return true, strings.Replace(url, url_placeholder, p+"."+e, -1) + tourl(params)
+ return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toUrl(params)
}
}
}
@@ -526,45 +530,43 @@ func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName strin
}
if u, ok := params[v]; ok {
delete(params, v)
- url = strings.Replace(url, url_placeholder, u, 1)
+ url = strings.Replace(url, urlPlaceholder, u, 1)
} else {
if canskip {
canskip = false
continue
- } else {
- return false, ""
}
+ return false, ""
}
}
- return true, url + tourl(params)
- } else {
- var i int
- var startreg bool
- regurl := ""
- for _, v := range strings.Trim(l.regexps.String(), "^$") {
- if v == '(' {
- startreg = true
- continue
- } else if v == ')' {
- startreg = false
- if v, ok := params[l.wildcards[i]]; ok {
- delete(params, l.wildcards[i])
- regurl = regurl + v
- i++
- } else {
- break
- }
- } else if !startreg {
- regurl = string(append([]rune(regurl), v))
+ return true, url + toUrl(params)
+ }
+ var i int
+ var startreg bool
+ regurl := ""
+ for _, v := range strings.Trim(l.regexps.String(), "^$") {
+ if v == '(' {
+ startreg = true
+ continue
+ } else if v == ')' {
+ startreg = false
+ if v, ok := params[l.wildcards[i]]; ok {
+ delete(params, l.wildcards[i])
+ regurl = regurl + v
+ i++
+ } else {
+ break
}
+ } else if !startreg {
+ regurl = string(append([]rune(regurl), v))
}
- if l.regexps.MatchString(regurl) {
- ps := strings.Split(regurl, "/")
- for _, p := range ps {
- url = strings.Replace(url, url_placeholder, p, 1)
- }
- return true, url + tourl(params)
+ }
+ if l.regexps.MatchString(regurl) {
+ ps := strings.Split(regurl, "/")
+ for _, p := range ps {
+ url = strings.Replace(url, urlPlaceholder, p, 1)
}
+ return true, url + toUrl(params)
}
}
}
@@ -574,168 +576,137 @@ func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName strin
return false, ""
}
+func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) {
+ if p.enableFilter {
+ if l, ok := p.filters[pos]; ok {
+ for _, filterR := range l {
+ if filterR.returnOnOutput && context.ResponseWriter.Started {
+ return true
+ }
+ if ok := filterR.ValidRouter(urlPath, context); ok {
+ filterR.filterFunc(context)
+ }
+ if filterR.returnOnOutput && context.ResponseWriter.Started {
+ return true
+ }
+ }
+ }
+ }
+ return false
+}
+
// Implement http.Handler interface.
-func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
- starttime := time.Now()
- var runrouter reflect.Type
- var findrouter bool
- var runMethod string
- var routerInfo *controllerInfo
-
- w := &responseWriter{writer: rw}
-
- if RunMode == "dev" {
- w.Header().Set("Server", BeegoServerName)
- }
-
- // init context
- context := &beecontext.Context{
- ResponseWriter: w,
- Request: r,
- Input: beecontext.NewInput(r),
- Output: beecontext.NewOutput(),
- }
- context.Output.Context = context
- context.Output.EnableGzip = EnableGzip
-
+func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
+ startTime := time.Now()
+ var (
+ runRouter reflect.Type
+ findRouter bool
+ runMethod string
+ routerInfo *controllerInfo
+ )
+ context := p.pool.Get().(*beecontext.Context)
+ context.Reset(rw, r)
+ defer p.pool.Put(context)
defer p.recoverPanic(context)
+ context.Output.EnableGzip = BConfig.EnableGzip
+
+ if BConfig.RunMode == DEV {
+ context.Output.Header("Server", BConfig.ServerName)
+ }
+
var urlPath string
- if !RouterCaseSensitive {
+ if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(r.URL.Path)
} else {
urlPath = r.URL.Path
}
- // defined filter function
- do_filter := func(pos int) (started bool) {
- if p.enableFilter {
- if l, ok := p.filters[pos]; ok {
- for _, filterR := range l {
- if filterR.returnOnOutput && w.started {
- return true
- }
- if ok, params := filterR.ValidRouter(urlPath); ok {
- for k, v := range params {
- if context.Input.Params == nil {
- context.Input.Params = make(map[string]string)
- }
- context.Input.Params[k] = v
- }
- filterR.filterFunc(context)
- }
- if filterR.returnOnOutput && w.started {
- return true
- }
- }
- }
- }
- return false
- }
- // filter wrong httpmethod
+ // filter wrong http method
if _, ok := HTTPMETHOD[r.Method]; !ok {
- http.Error(w, "Method Not Allowed", 405)
+ http.Error(rw, "Method Not Allowed", 405)
goto Admin
}
// filter for static file
- if do_filter(BeforeStatic) {
+ if p.execFilter(context, BeforeStatic, urlPath) {
goto Admin
}
serverStaticRouter(context)
- if w.started {
- findrouter = true
+ if context.ResponseWriter.Started {
+ findRouter = true
goto Admin
}
// session init
- if SessionOn {
+ if BConfig.WebConfig.Session.SessionOn {
var err error
- context.Input.CruSession, err = GlobalSessions.SessionStart(w, r)
+ context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
if err != nil {
Error(err)
exception("503", context)
return
}
defer func() {
- context.Input.CruSession.SessionRelease(w)
+ context.Input.CruSession.SessionRelease(rw)
}()
}
if r.Method != "GET" && r.Method != "HEAD" {
- if CopyRequestBody && !context.Input.IsUpload() {
- context.Input.CopyBody()
+ if BConfig.CopyRequestBody && !context.Input.IsUpload() {
+ context.Input.CopyBody(BConfig.MaxMemory)
}
- context.Input.ParseFormOrMulitForm(MaxMemory)
+ context.Input.ParseFormOrMulitForm(BConfig.MaxMemory)
}
- if do_filter(BeforeRouter) {
+ if p.execFilter(context, BeforeRouter, urlPath) {
goto Admin
}
- if context.Input.RunController != nil && context.Input.RunMethod != "" {
- findrouter = true
- runMethod = context.Input.RunMethod
- runrouter = context.Input.RunController
- }
-
- if !findrouter {
- http_method := r.Method
-
- if http_method == "POST" && context.Input.Query("_method") == "PUT" {
- http_method = "PUT"
- }
-
- if http_method == "POST" && context.Input.Query("_method") == "DELETE" {
- http_method = "DELETE"
- }
-
- if t, ok := p.routers[http_method]; ok {
- runObject, p := t.Match(urlPath)
+ if !findRouter {
+ httpMethod := r.Method
+ if t, ok := p.routers[httpMethod]; ok {
+ runObject := t.Match(urlPath, context)
if r, ok := runObject.(*controllerInfo); ok {
routerInfo = r
- findrouter = true
- if splat, ok := p[":splat"]; ok {
- splatlist := strings.Split(splat, "/")
- for k, v := range splatlist {
- p[strconv.Itoa(k)] = v
+ findRouter = true
+ if splat := context.Input.Param(":splat"); splat != "" {
+ for k, v := range strings.Split(splat, "/") {
+ context.Input.SetParam(strconv.Itoa(k), v)
}
}
- if p != nil {
- context.Input.Params = p
- }
}
}
}
//if no matches to url, throw a not found exception
- if !findrouter {
+ if !findRouter {
exception("404", context)
goto Admin
}
- if findrouter {
+ if findRouter {
//execute middleware filters
- if do_filter(BeforeExec) {
+ if p.execFilter(context, BeforeExec, urlPath) {
goto Admin
}
- isRunable := false
+ isRunnable := false
if routerInfo != nil {
if routerInfo.routerType == routerTypeRESTFul {
if _, ok := routerInfo.methods[r.Method]; ok {
- isRunable = true
- routerInfo.runfunction(context)
+ isRunnable = true
+ routerInfo.runFunction(context)
} else {
exception("405", context)
goto Admin
}
} else if routerInfo.routerType == routerTypeHandler {
- isRunable = true
+ isRunnable = true
routerInfo.handler.ServeHTTP(rw, r)
} else {
- runrouter = routerInfo.controllerType
+ runRouter = routerInfo.controllerType
method := r.Method
if r.Method == "POST" && context.Input.Query("_method") == "PUT" {
method = "PUT"
@@ -753,33 +724,33 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
}
- // also defined runrouter & runMethod from filter
- if !isRunable {
+ // also defined runRouter & runMethod from filter
+ if !isRunnable {
//Invoke the request handler
- vc := reflect.New(runrouter)
+ vc := reflect.New(runRouter)
execController, ok := vc.Interface().(ControllerInterface)
if !ok {
panic("controller is not ControllerInterface")
}
//call the controller init function
- execController.Init(context, runrouter.Name(), runMethod, vc.Interface())
+ execController.Init(context, runRouter.Name(), runMethod, vc.Interface())
//call prepare function
execController.Prepare()
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
- if EnableXSRF {
- execController.XsrfToken()
+ if BConfig.WebConfig.EnableXSRF {
+ execController.XSRFToken()
if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" ||
(r.Method == "POST" && (context.Input.Query("_method") == "DELETE" || context.Input.Query("_method") == "PUT")) {
- execController.CheckXsrfCookie()
+ execController.CheckXSRFCookie()
}
}
execController.URLMapping()
- if !w.started {
+ if !context.ResponseWriter.Started {
//exec main logic
switch runMethod {
case "GET":
@@ -798,15 +769,15 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.Options()
default:
if !execController.HandlerFunc(runMethod) {
- in := make([]reflect.Value, 0)
+ var in []reflect.Value
method := vc.MethodByName(runMethod)
method.Call(in)
}
}
//render template
- if !w.started && context.Output.Status == 0 {
- if AutoRender {
+ if !context.ResponseWriter.Started && context.Output.Status == 0 {
+ if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil {
panic(err)
}
@@ -814,69 +785,69 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
}
- // finish all runrouter. release resource
+ // finish all runRouter. release resource
execController.Finish()
}
//execute middleware filters
- if do_filter(AfterExec) {
+ if p.execFilter(context, AfterExec, urlPath) {
goto Admin
}
}
- do_filter(FinishRouter)
+ p.execFilter(context, FinishRouter, urlPath)
Admin:
- timeend := time.Since(starttime)
+ timeDur := time.Since(startTime)
//admin module record QPS
- if EnableAdmin {
- if FilterMonitorFunc(r.Method, r.URL.Path, timeend) {
- if runrouter != nil {
- go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runrouter.Name(), timeend)
+ if BConfig.Listen.EnableAdmin {
+ if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) {
+ if runRouter != nil {
+ go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
} else {
- go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, "", timeend)
+ go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, "", timeDur)
}
}
}
- if RunMode == "dev" || AccessLogs {
- var devinfo string
- if findrouter {
+ if BConfig.RunMode == DEV || BConfig.Log.AccessLogs {
+ var devInfo string
+ if findRouter {
if routerInfo != nil {
- devinfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s | % -40s |", r.Method, r.URL.Path, timeend.String(), "match", routerInfo.pattern)
+ devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s | % -40s |", r.Method, r.URL.Path, timeDur.String(), "match", routerInfo.pattern)
} else {
- devinfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeend.String(), "match")
+ devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "match")
}
} else {
- devinfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeend.String(), "notmatch")
+ devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch")
}
- if DefaultLogFilter == nil || !DefaultLogFilter.Filter(context) {
- Debug(devinfo)
+ if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) {
+ Debug(devInfo)
}
}
// Call WriteHeader if status code has been set changed
if context.Output.Status != 0 {
- w.writer.WriteHeader(context.Output.Status)
+ context.ResponseWriter.WriteHeader(context.Output.Status)
}
}
-func (p *ControllerRegistor) recoverPanic(context *beecontext.Context) {
+func (p *ControllerRegister) recoverPanic(context *beecontext.Context) {
if err := recover(); err != nil {
- if err == USERSTOPRUN {
+ if err == ErrAbort {
return
}
- if !RecoverPanic {
+ if !BConfig.RecoverPanic {
panic(err)
} else {
- if ErrorsShow {
+ if BConfig.EnableErrorsShow {
if _, ok := ErrorMaps[fmt.Sprint(err)]; ok {
exception(fmt.Sprint(err), context)
return
}
}
var stack string
- Critical("the request url is ", context.Input.Url())
+ Critical("the request url is ", context.Input.URL())
Critical("Handler crashed with error", err)
for i := 1; ; i++ {
_, file, line, ok := runtime.Caller(i)
@@ -886,52 +857,14 @@ func (p *ControllerRegistor) recoverPanic(context *beecontext.Context) {
Critical(fmt.Sprintf("%s:%d", file, line))
stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
}
- if RunMode == "dev" {
+ if BConfig.RunMode == DEV {
showErr(err, context, stack)
}
}
}
}
-//responseWriter is a wrapper for the http.ResponseWriter
-//started set to true if response was written to then don't execute other handler
-type responseWriter struct {
- writer http.ResponseWriter
- started bool
- status int
-}
-
-// Header returns the header map that will be sent by WriteHeader.
-func (w *responseWriter) Header() http.Header {
- return w.writer.Header()
-}
-
-// Write writes the data to the connection as part of an HTTP reply,
-// and sets `started` to true.
-// started means the response has sent out.
-func (w *responseWriter) Write(p []byte) (int, error) {
- w.started = true
- return w.writer.Write(p)
-}
-
-// WriteHeader sends an HTTP response header with status code,
-// and sets `started` to true.
-func (w *responseWriter) WriteHeader(code int) {
- w.status = code
- w.started = true
- w.writer.WriteHeader(code)
-}
-
-// hijacker for http
-func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
- hj, ok := w.writer.(http.Hijacker)
- if !ok {
- return nil, nil, errors.New("webserver doesn't support hijacking")
- }
- return hj.Hijack()
-}
-
-func tourl(params map[string]string) string {
+func toUrl(params map[string]string) string {
if len(params) == 0 {
return ""
}
diff --git a/router_test.go b/router_test.go
index 005f32d6..b0ae7a18 100644
--- a/router_test.go
+++ b/router_test.go
@@ -45,24 +45,24 @@ func (tc *TestController) List() {
}
func (tc *TestController) Params() {
- tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Params["0"] + tc.Ctx.Input.Params["1"] + tc.Ctx.Input.Params["2"]))
+ tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param("0") + tc.Ctx.Input.Param("1") + tc.Ctx.Input.Param("2")))
}
func (tc *TestController) Myext() {
tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param(":ext")))
}
-func (tc *TestController) GetUrl() {
- tc.Ctx.Output.Body([]byte(tc.UrlFor(".Myext")))
+func (tc *TestController) GetURL() {
+ tc.Ctx.Output.Body([]byte(tc.URLFor(".Myext")))
}
-func (t *TestController) GetParams() {
- t.Ctx.WriteString(t.Ctx.Input.Query(":last") + "+" +
- t.Ctx.Input.Query(":first") + "+" + t.Ctx.Input.Query("learn"))
+func (tc *TestController) GetParams() {
+ tc.Ctx.WriteString(tc.Ctx.Input.Query(":last") + "+" +
+ tc.Ctx.Input.Query(":first") + "+" + tc.Ctx.Input.Query("learn"))
}
-func (t *TestController) GetManyRouter() {
- t.Ctx.WriteString(t.Ctx.Input.Query(":id") + t.Ctx.Input.Query(":page"))
+func (tc *TestController) GetManyRouter() {
+ tc.Ctx.WriteString(tc.Ctx.Input.Query(":id") + tc.Ctx.Input.Query(":page"))
}
type ResStatus struct {
@@ -70,29 +70,29 @@ type ResStatus struct {
Msg string
}
-type JsonController struct {
+type JSONController struct {
Controller
}
-func (this *JsonController) Prepare() {
- this.Data["json"] = "prepare"
- this.ServeJson(true)
+func (jc *JSONController) Prepare() {
+ jc.Data["json"] = "prepare"
+ jc.ServeJSON(true)
}
-func (this *JsonController) Get() {
- this.Data["Username"] = "astaxie"
- this.Ctx.Output.Body([]byte("ok"))
+func (jc *JSONController) Get() {
+ jc.Data["Username"] = "astaxie"
+ jc.Ctx.Output.Body([]byte("ok"))
}
func TestUrlFor(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, "*:List")
handler.Add("/person/:last/:first", &TestController{}, "*:Param")
- if a := handler.UrlFor("TestController.List"); a != "/api/list" {
+ if a := handler.URLFor("TestController.List"); a != "/api/list" {
Info(a)
t.Errorf("TestController.List must equal to /api/list")
}
- if a := handler.UrlFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" {
+ if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" {
t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + a)
}
}
@@ -100,39 +100,39 @@ func TestUrlFor(t *testing.T) {
func TestUrlFor3(t *testing.T) {
handler := NewControllerRegister()
handler.AddAuto(&TestController{})
- if a := handler.UrlFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" {
+ if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" {
t.Errorf("TestController.Myext must equal to /test/myext, but get " + a)
}
- if a := handler.UrlFor("TestController.GetUrl"); a != "/test/geturl" && a != "/Test/GetUrl" {
- t.Errorf("TestController.GetUrl must equal to /test/geturl, but get " + a)
+ if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" {
+ t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a)
}
}
func TestUrlFor2(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List")
- handler.Add("/v1/:username/edit", &TestController{}, "get:GetUrl")
+ handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL")
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param")
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
- if handler.UrlFor("TestController.GetUrl", ":username", "astaxie") != "/v1/astaxie/edit" {
- Info(handler.UrlFor("TestController.GetUrl"))
+ if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
+ Info(handler.URLFor("TestController.GetURL"))
t.Errorf("TestController.List must equal to /v1/astaxie/edit")
}
- if handler.UrlFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") !=
+ if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za/cms_12_123.html" {
- Info(handler.UrlFor("TestController.List"))
+ Info(handler.URLFor("TestController.List"))
t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html")
}
- if handler.UrlFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") !=
+ if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za_cms/ttt_12_123.html" {
- Info(handler.UrlFor("TestController.Param"))
+ Info(handler.URLFor("TestController.Param"))
t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html")
}
- if handler.UrlFor("TestController.Get", ":year", "1111", ":month", "11",
+ if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11",
":title", "aaaa", ":entid", "aaaa") !=
"/1111/11/aaaa/aaaa" {
- Info(handler.UrlFor("TestController.Get"))
+ Info(handler.URLFor("TestController.Get"))
t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa")
}
}
@@ -270,7 +270,7 @@ func TestPrepare(t *testing.T) {
w := httptest.NewRecorder()
handler := NewControllerRegister()
- handler.Add("/json/list", &JsonController{})
+ handler.Add("/json/list", &JSONController{})
handler.ServeHTTP(w, r)
if w.Body.String() != `"prepare"` {
t.Errorf(w.Body.String() + "user define func can't run")
@@ -333,6 +333,18 @@ func TestRouterHandler(t *testing.T) {
}
}
+func TestRouterHandlerAll(t *testing.T) {
+ r, _ := http.NewRequest("POST", "/sayhi/a/b/c", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Handler("/sayhi", http.HandlerFunc(sayhello), true)
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "sayhello" {
+ t.Errorf("TestRouterHandler can't run")
+ }
+}
+
//
// Benchmarks NewApp:
//
diff --git a/session/couchbase/sess_couchbase.go b/session/couchbase/sess_couchbase.go
index 827d55d9..d5be11d0 100644
--- a/session/couchbase/sess_couchbase.go
+++ b/session/couchbase/sess_couchbase.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package couchbase for session provider
+// Package couchbase for session provider
//
// depend on github.com/couchbaselabs/go-couchbasee
//
@@ -30,21 +30,22 @@
// }
//
// more docs: http://beego.me/docs/module/session.md
-package session
+package couchbase
import (
"net/http"
"strings"
"sync"
- "github.com/couchbaselabs/go-couchbase"
+ couchbase "github.com/couchbase/go-couchbase"
"github.com/astaxie/beego/session"
)
-var couchbpder = &CouchbaseProvider{}
+var couchbpder = &Provider{}
-type CouchbaseSessionStore struct {
+// SessionStore store each session
+type SessionStore struct {
b *couchbase.Bucket
sid string
lock sync.RWMutex
@@ -52,7 +53,8 @@ type CouchbaseSessionStore struct {
maxlifetime int64
}
-type CouchbaseProvider struct {
+// Provider couchabse provided
+type Provider struct {
maxlifetime int64
savePath string
pool string
@@ -60,42 +62,47 @@ type CouchbaseProvider struct {
b *couchbase.Bucket
}
-func (cs *CouchbaseSessionStore) Set(key, value interface{}) error {
+// Set value to couchabse session
+func (cs *SessionStore) Set(key, value interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values[key] = value
return nil
}
-func (cs *CouchbaseSessionStore) Get(key interface{}) interface{} {
+// Get value from couchabse session
+func (cs *SessionStore) Get(key interface{}) interface{} {
cs.lock.RLock()
defer cs.lock.RUnlock()
if v, ok := cs.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-func (cs *CouchbaseSessionStore) Delete(key interface{}) error {
+// Delete value in couchbase session by given key
+func (cs *SessionStore) Delete(key interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
delete(cs.values, key)
return nil
}
-func (cs *CouchbaseSessionStore) Flush() error {
+// Flush Clean all values in couchbase session
+func (cs *SessionStore) Flush() error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values = make(map[interface{}]interface{})
return nil
}
-func (cs *CouchbaseSessionStore) SessionID() string {
+// SessionID Get couchbase session store id
+func (cs *SessionStore) SessionID() string {
return cs.sid
}
-func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) {
+// SessionRelease Write couchbase session with Gob string
+func (cs *SessionStore) SessionRelease(w http.ResponseWriter) {
defer cs.b.Close()
bo, err := session.EncodeGob(cs.values)
@@ -106,7 +113,7 @@ func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) {
cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
}
-func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket {
+func (cp *Provider) getBucket() *couchbase.Bucket {
c, err := couchbase.Connect(cp.savePath)
if err != nil {
return nil
@@ -125,10 +132,10 @@ func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket {
return bucket
}
-// init couchbase session
+// SessionInit init couchbase session
// savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket
-func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) error {
+func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error {
cp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
@@ -144,8 +151,8 @@ func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) err
return nil
}
-// read couchbase session by sid
-func (cp *CouchbaseProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead read couchbase session by sid
+func (cp *Provider) SessionRead(sid string) (session.Store, error) {
cp.b = cp.getBucket()
var doc []byte
@@ -161,11 +168,13 @@ func (cp *CouchbaseProvider) SessionRead(sid string) (session.SessionStore, erro
}
}
- cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
+ cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
-func (cp *CouchbaseProvider) SessionExist(sid string) bool {
+// SessionExist Check couchbase session exist.
+// it checkes sid exist or not.
+func (cp *Provider) SessionExist(sid string) bool {
cp.b = cp.getBucket()
defer cp.b.Close()
@@ -173,12 +182,12 @@ func (cp *CouchbaseProvider) SessionExist(sid string) bool {
if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
return false
- } else {
- return true
}
+ return true
}
-func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate remove oldsid and use sid to generate new session
+func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
cp.b = cp.getBucket()
var doc []byte
@@ -206,11 +215,12 @@ func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (session.Sess
}
}
- cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
+ cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
-func (cp *CouchbaseProvider) SessionDestroy(sid string) error {
+// SessionDestroy Remove bucket in this couchbase
+func (cp *Provider) SessionDestroy(sid string) error {
cp.b = cp.getBucket()
defer cp.b.Close()
@@ -218,11 +228,13 @@ func (cp *CouchbaseProvider) SessionDestroy(sid string) error {
return nil
}
-func (cp *CouchbaseProvider) SessionGC() {
+// SessionGC Recycle
+func (cp *Provider) SessionGC() {
return
}
-func (cp *CouchbaseProvider) SessionAll() int {
+// SessionAll return all active session
+func (cp *Provider) SessionAll() int {
return 0
}
diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go
index 643b8817..68f37b08 100644
--- a/session/ledis/ledis_session.go
+++ b/session/ledis/ledis_session.go
@@ -1,4 +1,5 @@
-package session
+// Package ledis provide session Provider
+package ledis
import (
"net/http"
@@ -11,59 +12,58 @@ import (
"github.com/siddontang/ledisdb/ledis"
)
-var ledispder = &LedisProvider{}
+var ledispder = &Provider{}
var c *ledis.DB
-// ledis session store
-type LedisSessionStore struct {
+// SessionStore ledis session store
+type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
-// set value in ledis session
-func (ls *LedisSessionStore) Set(key, value interface{}) error {
+// Set value in ledis session
+func (ls *SessionStore) Set(key, value interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values[key] = value
return nil
}
-// get value in ledis session
-func (ls *LedisSessionStore) Get(key interface{}) interface{} {
+// Get value in ledis session
+func (ls *SessionStore) Get(key interface{}) interface{} {
ls.lock.RLock()
defer ls.lock.RUnlock()
if v, ok := ls.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete value in ledis session
-func (ls *LedisSessionStore) Delete(key interface{}) error {
+// Delete value in ledis session
+func (ls *SessionStore) Delete(key interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
delete(ls.values, key)
return nil
}
-// clear all values in ledis session
-func (ls *LedisSessionStore) Flush() error {
+// Flush clear all values in ledis session
+func (ls *SessionStore) Flush() error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values = make(map[interface{}]interface{})
return nil
}
-// get ledis session id
-func (ls *LedisSessionStore) SessionID() string {
+// SessionID get ledis session id
+func (ls *SessionStore) SessionID() string {
return ls.sid
}
-// save session values to ledis
-func (ls *LedisSessionStore) SessionRelease(w http.ResponseWriter) {
+// SessionRelease save session values to ledis
+func (ls *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(ls.values)
if err != nil {
return
@@ -72,17 +72,17 @@ func (ls *LedisSessionStore) SessionRelease(w http.ResponseWriter) {
c.Expire([]byte(ls.sid), ls.maxlifetime)
}
-// ledis session provider
-type LedisProvider struct {
+// Provider ledis session provider
+type Provider struct {
maxlifetime int64
savePath string
db int
}
-// init ledis session
+// SessionInit init ledis session
// savepath like ledis server saveDataPath,pool size
// e.g. 127.0.0.1:6379,100,astaxie
-func (lp *LedisProvider) SessionInit(maxlifetime int64, savePath string) error {
+func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error {
var err error
lp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
@@ -106,8 +106,8 @@ func (lp *LedisProvider) SessionInit(maxlifetime int64, savePath string) error {
return nil
}
-// read ledis session by sid
-func (lp *LedisProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead read ledis session by sid
+func (lp *Provider) SessionRead(sid string) (session.Store, error) {
kvs, err := c.Get([]byte(sid))
var kv map[interface{}]interface{}
if len(kvs) == 0 {
@@ -118,22 +118,21 @@ func (lp *LedisProvider) SessionRead(sid string) (session.SessionStore, error) {
return nil, err
}
}
- ls := &LedisSessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
+ ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
return ls, nil
}
-// check ledis session exist by sid
-func (lp *LedisProvider) SessionExist(sid string) bool {
+// SessionExist check ledis session exist by sid
+func (lp *Provider) SessionExist(sid string) bool {
count, _ := c.Exists([]byte(sid))
if count == 0 {
return false
- } else {
- return true
}
+ return true
}
-// generate new sid for ledis session
-func (lp *LedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate generate new sid for ledis session
+func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
count, _ := c.Exists([]byte(sid))
if count == 0 {
// oldsid doesn't exists, set the new sid directly
@@ -156,23 +155,23 @@ func (lp *LedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionS
return nil, err
}
}
- ls := &LedisSessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
+ ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
return ls, nil
}
-// delete ledis session by id
-func (lp *LedisProvider) SessionDestroy(sid string) error {
+// SessionDestroy delete ledis session by id
+func (lp *Provider) SessionDestroy(sid string) error {
c.Del([]byte(sid))
return nil
}
-// Impelment method, no used.
-func (lp *LedisProvider) SessionGC() {
+// SessionGC Impelment method, no used.
+func (lp *Provider) SessionGC() {
return
}
-// @todo
-func (lp *LedisProvider) SessionAll() int {
+// SessionAll return all active session
+func (lp *Provider) SessionAll() int {
return 0
}
func init() {
diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go
index bb33075a..f1069bc9 100644
--- a/session/memcache/sess_memcache.go
+++ b/session/memcache/sess_memcache.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package memcache for session provider
+// Package memcache for session provider
//
// depend on github.com/bradfitz/gomemcache/memcache
//
@@ -30,7 +30,7 @@
// }
//
// more docs: http://beego.me/docs/module/session.md
-package session
+package memcache
import (
"net/http"
@@ -45,56 +45,55 @@ import (
var mempder = &MemProvider{}
var client *memcache.Client
-// memcache session store
-type MemcacheSessionStore struct {
+// SessionStore memcache session store
+type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
-// set value in memcache session
-func (rs *MemcacheSessionStore) Set(key, value interface{}) error {
+// Set value in memcache session
+func (rs *SessionStore) Set(key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
-// get value in memcache session
-func (rs *MemcacheSessionStore) Get(key interface{}) interface{} {
+// Get value in memcache session
+func (rs *SessionStore) Get(key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete value in memcache session
-func (rs *MemcacheSessionStore) Delete(key interface{}) error {
+// Delete value in memcache session
+func (rs *SessionStore) Delete(key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
-// clear all values in memcache session
-func (rs *MemcacheSessionStore) Flush() error {
+// Flush clear all values in memcache session
+func (rs *SessionStore) Flush() error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
-// get memcache session id
-func (rs *MemcacheSessionStore) SessionID() string {
+// SessionID get memcache session id
+func (rs *SessionStore) SessionID() string {
return rs.sid
}
-// save session values to memcache
-func (rs *MemcacheSessionStore) SessionRelease(w http.ResponseWriter) {
+// SessionRelease save session values to memcache
+func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
@@ -103,7 +102,7 @@ func (rs *MemcacheSessionStore) SessionRelease(w http.ResponseWriter) {
client.Set(&item)
}
-// memcahe session provider
+// MemProvider memcache session provider
type MemProvider struct {
maxlifetime int64
conninfo []string
@@ -111,7 +110,7 @@ type MemProvider struct {
password string
}
-// init memcache session
+// SessionInit init memcache session
// savepath like
// e.g. 127.0.0.1:9090
func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
@@ -121,8 +120,8 @@ func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
return nil
}
-// read memcache session by sid
-func (rp *MemProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead read memcache session by sid
+func (rp *MemProvider) SessionRead(sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
@@ -130,7 +129,7 @@ func (rp *MemProvider) SessionRead(sid string) (session.SessionStore, error) {
}
item, err := client.Get(sid)
if err != nil && err == memcache.ErrCacheMiss {
- rs := &MemcacheSessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime}
+ rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime}
return rs, nil
}
var kv map[interface{}]interface{}
@@ -142,11 +141,11 @@ func (rp *MemProvider) SessionRead(sid string) (session.SessionStore, error) {
return nil, err
}
}
- rs := &MemcacheSessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
-// check memcache session exist by sid
+// SessionExist check memcache session exist by sid
func (rp *MemProvider) SessionExist(sid string) bool {
if client == nil {
if err := rp.connectInit(); err != nil {
@@ -155,13 +154,12 @@ func (rp *MemProvider) SessionExist(sid string) bool {
}
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
return false
- } else {
- return true
}
+ return true
}
-// generate new sid for memcache session
-func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate generate new sid for memcache session
+func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
@@ -195,11 +193,11 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.SessionSto
}
}
- rs := &MemcacheSessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
-// delete memcache session by id
+// SessionDestroy delete memcache session by id
func (rp *MemProvider) SessionDestroy(sid string) error {
if client == nil {
if err := rp.connectInit(); err != nil {
@@ -219,12 +217,12 @@ func (rp *MemProvider) connectInit() error {
return nil
}
-// Impelment method, no used.
+// SessionGC Impelment method, no used.
func (rp *MemProvider) SessionGC() {
return
}
-// @todo
+// SessionAll return all activeSession
func (rp *MemProvider) SessionAll() int {
return 0
}
diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go
index 76a13932..969d26c9 100644
--- a/session/mysql/sess_mysql.go
+++ b/session/mysql/sess_mysql.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package mysql for session provider
+// Package mysql for session provider
//
// depends on github.com/go-sql-driver/mysql:
//
@@ -38,7 +38,7 @@
// }
//
// more docs: http://beego.me/docs/module/session.md
-package session
+package mysql
import (
"database/sql"
@@ -47,82 +47,85 @@ import (
"time"
"github.com/astaxie/beego/session"
-
+ // import mysql driver
_ "github.com/go-sql-driver/mysql"
)
-var mysqlpder = &MysqlProvider{}
+var (
+ // TableName store the session in MySQL
+ TableName = "session"
+ mysqlpder = &Provider{}
+)
-// mysql session store
-type MysqlSessionStore struct {
+// SessionStore mysql session store
+type SessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
-// set value in mysql session.
+// Set value in mysql session.
// it is temp value in map.
-func (st *MysqlSessionStore) Set(key, value interface{}) error {
+func (st *SessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
-// get value from mysql session
-func (st *MysqlSessionStore) Get(key interface{}) interface{} {
+// Get value from mysql session
+func (st *SessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete value in mysql session
-func (st *MysqlSessionStore) Delete(key interface{}) error {
+// Delete value in mysql session
+func (st *SessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
-// clear all values in mysql session
-func (st *MysqlSessionStore) Flush() error {
+// Flush clear all values in mysql session
+func (st *SessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
-// get session id of this mysql session store
-func (st *MysqlSessionStore) SessionID() string {
+// SessionID get session id of this mysql session store
+func (st *SessionStore) SessionID() string {
return st.sid
}
-// save mysql session values to database.
+// SessionRelease save mysql session values to database.
// must call this method to save values to database.
-func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
+func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
return
}
- st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
+ st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid)
}
-// mysql session provider
-type MysqlProvider struct {
+// Provider mysql session provider
+type Provider struct {
maxlifetime int64
savePath string
}
// connect to mysql
-func (mp *MysqlProvider) connectInit() *sql.DB {
+func (mp *Provider) connectInit() *sql.DB {
db, e := sql.Open("mysql", mp.savePath)
if e != nil {
return nil
@@ -130,22 +133,22 @@ func (mp *MysqlProvider) connectInit() *sql.DB {
return db
}
-// init mysql session.
+// SessionInit init mysql session.
// savepath is the connection string of mysql.
-func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error {
+func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
-// get mysql session by sid
-func (mp *MysqlProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead get mysql session by sid
+func (mp *Provider) SessionRead(sid string) (session.Store, error) {
c := mp.connectInit()
- row := c.QueryRow("select session_data from session where session_key=?", sid)
+ row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
- c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
+ c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
sid, "", time.Now().Unix())
}
var kv map[interface{}]interface{}
@@ -157,34 +160,33 @@ func (mp *MysqlProvider) SessionRead(sid string) (session.SessionStore, error) {
return nil, err
}
}
- rs := &MysqlSessionStore{c: c, sid: sid, values: kv}
+ rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
-// check mysql session exist
-func (mp *MysqlProvider) SessionExist(sid string) bool {
+// SessionExist check mysql session exist
+func (mp *Provider) SessionExist(sid string) bool {
c := mp.connectInit()
defer c.Close()
- row := c.QueryRow("select session_data from session where session_key=?", sid)
+ row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
return false
- } else {
- return true
}
+ return true
}
-// generate new sid for mysql session
-func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate generate new sid for mysql session
+func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
- row := c.QueryRow("select session_data from session where session_key=?", oldsid)
+ row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
- c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
+ c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
}
- c.Exec("update session set `session_key`=? where session_key=?", sid, oldsid)
+ c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
@@ -194,32 +196,32 @@ func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionS
return nil, err
}
}
- rs := &MysqlSessionStore{c: c, sid: sid, values: kv}
+ rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
-// delete mysql session by sid
-func (mp *MysqlProvider) SessionDestroy(sid string) error {
+// SessionDestroy delete mysql session by sid
+func (mp *Provider) SessionDestroy(sid string) error {
c := mp.connectInit()
- c.Exec("DELETE FROM session where session_key=?", sid)
+ c.Exec("DELETE FROM "+TableName+" where session_key=?", sid)
c.Close()
return nil
}
-// delete expired values in mysql session
-func (mp *MysqlProvider) SessionGC() {
+// SessionGC delete expired values in mysql session
+func (mp *Provider) SessionGC() {
c := mp.connectInit()
- c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
+ c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
c.Close()
return
}
-// count values in mysql session
-func (mp *MysqlProvider) SessionAll() int {
+// SessionAll count values in mysql session
+func (mp *Provider) SessionAll() int {
c := mp.connectInit()
defer c.Close()
var total int
- err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
+ err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total)
if err != nil {
return 0
}
diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go
index ac9a1612..73f9c13a 100644
--- a/session/postgres/sess_postgresql.go
+++ b/session/postgres/sess_postgresql.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package postgresql for session provider
+// Package postgres for session provider
//
// depends on github.com/lib/pq:
//
@@ -27,9 +27,9 @@
// session_expiry timestamp NOT NULL,
// CONSTRAINT session_key PRIMARY KEY(session_key)
// );
-
+//
// will be activated with these settings in app.conf:
-
+//
// SessionOn = true
// SessionProvider = postgresql
// SessionSavePath = "user=a password=b dbname=c sslmode=disable"
@@ -48,7 +48,7 @@
// }
//
// more docs: http://beego.me/docs/module/session.md
-package session
+package postgres
import (
"database/sql"
@@ -57,64 +57,63 @@ import (
"time"
"github.com/astaxie/beego/session"
-
+ // import postgresql Driver
_ "github.com/lib/pq"
)
-var postgresqlpder = &PostgresqlProvider{}
+var postgresqlpder = &Provider{}
-// postgresql session store
-type PostgresqlSessionStore struct {
+// SessionStore postgresql session store
+type SessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
-// set value in postgresql session.
+// Set value in postgresql session.
// it is temp value in map.
-func (st *PostgresqlSessionStore) Set(key, value interface{}) error {
+func (st *SessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
-// get value from postgresql session
-func (st *PostgresqlSessionStore) Get(key interface{}) interface{} {
+// Get value from postgresql session
+func (st *SessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete value in postgresql session
-func (st *PostgresqlSessionStore) Delete(key interface{}) error {
+// Delete value in postgresql session
+func (st *SessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
-// clear all values in postgresql session
-func (st *PostgresqlSessionStore) Flush() error {
+// Flush clear all values in postgresql session
+func (st *SessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
-// get session id of this postgresql session store
-func (st *PostgresqlSessionStore) SessionID() string {
+// SessionID get session id of this postgresql session store
+func (st *SessionStore) SessionID() string {
return st.sid
}
-// save postgresql session values to database.
+// SessionRelease save postgresql session values to database.
// must call this method to save values to database.
-func (st *PostgresqlSessionStore) SessionRelease(w http.ResponseWriter) {
+func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
@@ -125,14 +124,14 @@ func (st *PostgresqlSessionStore) SessionRelease(w http.ResponseWriter) {
}
-// postgresql session provider
-type PostgresqlProvider struct {
+// Provider postgresql session provider
+type Provider struct {
maxlifetime int64
savePath string
}
// connect to postgresql
-func (mp *PostgresqlProvider) connectInit() *sql.DB {
+func (mp *Provider) connectInit() *sql.DB {
db, e := sql.Open("postgres", mp.savePath)
if e != nil {
return nil
@@ -140,16 +139,16 @@ func (mp *PostgresqlProvider) connectInit() *sql.DB {
return db
}
-// init postgresql session.
+// SessionInit init postgresql session.
// savepath is the connection string of postgresql.
-func (mp *PostgresqlProvider) SessionInit(maxlifetime int64, savePath string) error {
+func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
-// get postgresql session by sid
-func (mp *PostgresqlProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead get postgresql session by sid
+func (mp *Provider) SessionRead(sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
@@ -174,12 +173,12 @@ func (mp *PostgresqlProvider) SessionRead(sid string) (session.SessionStore, err
return nil, err
}
}
- rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
+ rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
-// check postgresql session exist
-func (mp *PostgresqlProvider) SessionExist(sid string) bool {
+// SessionExist check postgresql session exist
+func (mp *Provider) SessionExist(sid string) bool {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
@@ -188,13 +187,12 @@ func (mp *PostgresqlProvider) SessionExist(sid string) bool {
if err == sql.ErrNoRows {
return false
- } else {
- return true
}
+ return true
}
-// generate new sid for postgresql session
-func (mp *PostgresqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate generate new sid for postgresql session
+func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
var sessiondata []byte
@@ -213,28 +211,28 @@ func (mp *PostgresqlProvider) SessionRegenerate(oldsid, sid string) (session.Ses
return nil, err
}
}
- rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
+ rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
-// delete postgresql session by sid
-func (mp *PostgresqlProvider) SessionDestroy(sid string) error {
+// SessionDestroy delete postgresql session by sid
+func (mp *Provider) SessionDestroy(sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=$1", sid)
c.Close()
return nil
}
-// delete expired values in postgresql session
-func (mp *PostgresqlProvider) SessionGC() {
+// SessionGC delete expired values in postgresql session
+func (mp *Provider) SessionGC() {
c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close()
return
}
-// count values in postgresql session
-func (mp *PostgresqlProvider) SessionAll() int {
+// SessionAll count values in postgresql session
+func (mp *Provider) SessionAll() int {
c := mp.connectInit()
defer c.Close()
var total int
diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go
index d31feb2c..99a672d5 100644
--- a/session/redis/sess_redis.go
+++ b/session/redis/sess_redis.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package redis for session provider
+// Package redis for session provider
//
// depend on github.com/garyburd/redigo/redis
//
@@ -30,7 +30,7 @@
// }
//
// more docs: http://beego.me/docs/module/session.md
-package session
+package redis
import (
"net/http"
@@ -43,15 +43,13 @@ import (
"github.com/garyburd/redigo/redis"
)
-var redispder = &RedisProvider{}
+var redispder = &Provider{}
// redis max pool size
-var MAX_POOL_SIZE = 100
+var MaxPoolSize = 100
-var redisPool chan redis.Conn
-
-// redis session store
-type RedisSessionStore struct {
+// SessionStore redis session store
+type SessionStore struct {
p *redis.Pool
sid string
lock sync.RWMutex
@@ -59,61 +57,63 @@ type RedisSessionStore struct {
maxlifetime int64
}
-// set value in redis session
-func (rs *RedisSessionStore) Set(key, value interface{}) error {
+// Set value in redis session
+func (rs *SessionStore) Set(key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
-// get value in redis session
-func (rs *RedisSessionStore) Get(key interface{}) interface{} {
+// Get value in redis session
+func (rs *SessionStore) Get(key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete value in redis session
-func (rs *RedisSessionStore) Delete(key interface{}) error {
+// Delete value in redis session
+func (rs *SessionStore) Delete(key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
-// clear all values in redis session
-func (rs *RedisSessionStore) Flush() error {
+// Flush clear all values in redis session
+func (rs *SessionStore) Flush() error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
-// get redis session id
-func (rs *RedisSessionStore) SessionID() string {
+// SessionID get redis session id
+func (rs *SessionStore) SessionID() string {
return rs.sid
}
-// save session values to redis
-func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
- c := rs.p.Get()
- defer c.Close()
+// SessionRelease save session values to redis
+func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
- c.Do("SETEX", rs.sid, rs.maxlifetime, string(b))
+ c := rs.p.Get()
+ defer c.Close()
+ // Update session value if exists or error.
+ if existed, err := redis.Bool(c.Do("EXISTS", rs.sid)); existed || err != nil {
+ c.Do("SETEX", rs.sid, rs.maxlifetime, string(b))
+ }
}
-// redis session provider
-type RedisProvider struct {
+// Provider redis session provider
+type Provider struct {
maxlifetime int64
savePath string
poolsize int
@@ -122,10 +122,10 @@ type RedisProvider struct {
poollist *redis.Pool
}
-// init redis session
+// SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379,100,astaxie,0
-func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
+func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
@@ -134,12 +134,12 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize <= 0 {
- rp.poolsize = MAX_POOL_SIZE
+ rp.poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
}
} else {
- rp.poolsize = MAX_POOL_SIZE
+ rp.poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
@@ -176,8 +176,8 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
return rp.poollist.Get().Err()
}
-// read redis session by sid
-func (rp *RedisProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead read redis session by sid
+func (rp *Provider) SessionRead(sid string) (session.Store, error) {
c := rp.poollist.Get()
defer c.Close()
@@ -192,24 +192,23 @@ func (rp *RedisProvider) SessionRead(sid string) (session.SessionStore, error) {
}
}
- rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
-// check redis session exist by sid
-func (rp *RedisProvider) SessionExist(sid string) bool {
+// SessionExist check redis session exist by sid
+func (rp *Provider) SessionExist(sid string) bool {
c := rp.poollist.Get()
defer c.Close()
if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 {
return false
- } else {
- return true
}
+ return true
}
-// generate new sid for redis session
-func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate generate new sid for redis session
+func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
c := rp.poollist.Get()
defer c.Close()
@@ -234,12 +233,12 @@ func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionS
}
}
- rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
-// delete redis session by id
-func (rp *RedisProvider) SessionDestroy(sid string) error {
+// SessionDestroy delete redis session by id
+func (rp *Provider) SessionDestroy(sid string) error {
c := rp.poollist.Get()
defer c.Close()
@@ -247,13 +246,13 @@ func (rp *RedisProvider) SessionDestroy(sid string) error {
return nil
}
-// Impelment method, no used.
-func (rp *RedisProvider) SessionGC() {
+// SessionGC Impelment method, no used.
+func (rp *Provider) SessionGC() {
return
}
-// @todo
-func (rp *RedisProvider) SessionAll() int {
+// SessionAll return all activeSession
+func (rp *Provider) SessionAll() int {
return 0
}
diff --git a/session/redis/sess_redis_test.go b/session/redis/sess_redis_test.go
new file mode 100644
index 00000000..0c634428
--- /dev/null
+++ b/session/redis/sess_redis_test.go
@@ -0,0 +1,64 @@
+// Copyright 2016 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package redis
+
+import (
+ "testing"
+)
+
+func TestSessionRelease(t *testing.T) {
+
+ provider := Provider{}
+ if err := provider.SessionInit(3, "127.0.0.1:6379"); err != nil {
+ t.Fatal("init session err,", err)
+ }
+
+ sessionID := "beegosessionid_00001"
+
+ session, err := provider.SessionRegenerate("", sessionID)
+ if err != nil {
+ t.Fatal("new session error,", err)
+ }
+
+ // set item.
+ session.Set("k1", "v1")
+ // update.
+ session.SessionRelease(nil)
+
+ session, err = provider.SessionRead(sessionID)
+ if err != nil {
+ t.Fatal("read session error,", err)
+ }
+ if v1 := session.Get("k1"); v1 == nil {
+ t.Fatal("want v1 got nil")
+ } else if v, _ := v1.(string); v != "v1" {
+ t.Fatalf("want v1 got %s", v)
+ }
+
+ // delete
+ provider.SessionDestroy(sessionID)
+ session.Set("k2", "v2")
+
+ session.SessionRelease(nil)
+
+ session, err = provider.SessionRead(sessionID)
+ if err != nil {
+ t.Fatal("read session error,", err)
+ }
+ if session.Get("k1") != nil || session.Get("k2") != nil {
+ t.Fatalf("want emtpy session value,got %s,%s", session.Get("k1"), session.Get("k2"))
+ }
+
+}
diff --git a/session/sess_cookie.go b/session/sess_cookie.go
index 01dc505c..3fefa360 100644
--- a/session/sess_cookie.go
+++ b/session/sess_cookie.go
@@ -25,7 +25,7 @@ import (
var cookiepder = &CookieProvider{}
-// Cookie SessionStore
+// CookieSessionStore Cookie SessionStore
type CookieSessionStore struct {
sid string
values map[interface{}]interface{} // session data
@@ -47,9 +47,8 @@ func (st *CookieSessionStore) Get(key interface{}) interface{} {
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
// Delete value in cookie session
@@ -60,7 +59,7 @@ func (st *CookieSessionStore) Delete(key interface{}) error {
return nil
}
-// Clean all values in cookie session
+// Flush Clean all values in cookie session
func (st *CookieSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
@@ -68,12 +67,12 @@ func (st *CookieSessionStore) Flush() error {
return nil
}
-// Return id of this cookie session
+// SessionID Return id of this cookie session
func (st *CookieSessionStore) SessionID() string {
return st.sid
}
-// Write cookie session to http response cookie
+// SessionRelease Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
str, err := encodeCookie(cookiepder.block,
cookiepder.config.SecurityKey,
@@ -101,14 +100,14 @@ type cookieConfig struct {
Maxage int `json:"maxage"`
}
-// Cookie session provider
+// CookieProvider Cookie session provider
type CookieProvider struct {
maxlifetime int64
config *cookieConfig
block cipher.Block
}
-// Init cookie session provider with max lifetime and config json.
+// SessionInit Init cookie session provider with max lifetime and config json.
// maxlifetime is ignored.
// json config:
// securityKey - hash string
@@ -136,9 +135,9 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error
return nil
}
-// Get SessionStore in cooke.
+// SessionRead Get SessionStore in cooke.
// decode cooke string to map and put into SessionStore with sid.
-func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
+func (pder *CookieProvider) SessionRead(sid string) (Store, error) {
maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
@@ -150,32 +149,32 @@ func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
return rs, nil
}
-// Cookie session is always existed
+// SessionExist Cookie session is always existed
func (pder *CookieProvider) SessionExist(sid string) bool {
return true
}
-// Implement method, no used.
-func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
+// SessionRegenerate Implement method, no used.
+func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
return nil, nil
}
-// Implement method, no used.
+// SessionDestroy Implement method, no used.
func (pder *CookieProvider) SessionDestroy(sid string) error {
return nil
}
-// Implement method, no used.
+// SessionGC Implement method, no used.
func (pder *CookieProvider) SessionGC() {
return
}
-// Implement method, return 0.
+// SessionAll Implement method, return 0.
func (pder *CookieProvider) SessionAll() int {
return 0
}
-// Implement method, no used.
+// SessionUpdate Implement method, no used.
func (pder *CookieProvider) SessionUpdate(sid string) error {
return nil
}
diff --git a/session/sess_cookie_test.go b/session/sess_cookie_test.go
index fe3ac806..b5982260 100644
--- a/session/sess_cookie_test.go
+++ b/session/sess_cookie_test.go
@@ -53,3 +53,44 @@ func TestCookie(t *testing.T) {
}
}
}
+
+func TestDestorySessionCookie(t *testing.T) {
+ config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
+ globalSessions, err := NewManager("cookie", config)
+ if err != nil {
+ t.Fatal("init cookie session err", err)
+ }
+
+ r, _ := http.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+ session, err := globalSessions.SessionStart(w, r)
+ if err != nil {
+ t.Fatal("session start err,", err)
+ }
+
+ // request again ,will get same sesssion id .
+ r1, _ := http.NewRequest("GET", "/", nil)
+ r1.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
+ w = httptest.NewRecorder()
+ newSession, err := globalSessions.SessionStart(w, r1)
+ if err != nil {
+ t.Fatal("session start err,", err)
+ }
+ if newSession.SessionID() != session.SessionID() {
+ t.Fatal("get cookie session id is not the same again.")
+ }
+
+ // After destory session , will get a new session id .
+ globalSessions.SessionDestroy(w, r1)
+ r2, _ := http.NewRequest("GET", "/", nil)
+ r2.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
+
+ w = httptest.NewRecorder()
+ newSession, err = globalSessions.SessionStart(w, r2)
+ if err != nil {
+ t.Fatal("session start error")
+ }
+ if newSession.SessionID() == session.SessionID() {
+ t.Fatal("after destory session and reqeust again ,get cookie session id is same.")
+ }
+}
diff --git a/session/sess_file.go b/session/sess_file.go
index b1084acf..9265b030 100644
--- a/session/sess_file.go
+++ b/session/sess_file.go
@@ -32,7 +32,7 @@ var (
gcmaxlifetime int64
)
-// File session store
+// FileSessionStore File session store
type FileSessionStore struct {
sid string
lock sync.RWMutex
@@ -53,9 +53,8 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} {
defer fs.lock.RUnlock()
if v, ok := fs.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
// Delete value in file session by given key
@@ -66,7 +65,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error {
return nil
}
-// Clean all values in file session
+// Flush Clean all values in file session
func (fs *FileSessionStore) Flush() error {
fs.lock.Lock()
defer fs.lock.Unlock()
@@ -74,12 +73,12 @@ func (fs *FileSessionStore) Flush() error {
return nil
}
-// Get file session store id
+// SessionID Get file session store id
func (fs *FileSessionStore) SessionID() string {
return fs.sid
}
-// Write file session to local file with Gob string
+// SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
b, err := EncodeGob(fs.values)
if err != nil {
@@ -100,14 +99,14 @@ func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
f.Close()
}
-// File session provider
+// FileProvider File session provider
type FileProvider struct {
lock sync.RWMutex
maxlifetime int64
savePath string
}
-// Init file session provider.
+// SessionInit Init file session provider.
// savePath sets the session files path.
func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
fp.maxlifetime = maxlifetime
@@ -115,10 +114,10 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
return nil
}
-// Read file session by sid.
+// SessionRead Read file session by sid.
// if file is not exist, create it.
// the file path is generated from sid string.
-func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) {
+func (fp *FileProvider) SessionRead(sid string) (Store, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
@@ -154,7 +153,7 @@ func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) {
return ss, nil
}
-// Check file session exist.
+// SessionExist Check file session exist.
// it checkes the file named from sid exist or not.
func (fp *FileProvider) SessionExist(sid string) bool {
filepder.lock.Lock()
@@ -163,12 +162,11 @@ func (fp *FileProvider) SessionExist(sid string) bool {
_, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
if err == nil {
return true
- } else {
- return false
}
+ return false
}
-// Remove all files in this save path
+// SessionDestroy Remove all files in this save path
func (fp *FileProvider) SessionDestroy(sid string) error {
filepder.lock.Lock()
defer filepder.lock.Unlock()
@@ -176,7 +174,7 @@ func (fp *FileProvider) SessionDestroy(sid string) error {
return nil
}
-// Recycle files in save path
+// SessionGC Recycle files in save path
func (fp *FileProvider) SessionGC() {
filepder.lock.Lock()
defer filepder.lock.Unlock()
@@ -185,7 +183,7 @@ func (fp *FileProvider) SessionGC() {
filepath.Walk(fp.savePath, gcpath)
}
-// Get active file session number.
+// SessionAll Get active file session number.
// it walks save path to count files.
func (fp *FileProvider) SessionAll() int {
a := &activeSession{}
@@ -199,9 +197,9 @@ func (fp *FileProvider) SessionAll() int {
return a.total
}
-// Generate new sid for file session.
+// SessionRegenerate Generate new sid for file session.
// it delete old file and create new file named from new sid.
-func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
+func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
@@ -269,14 +267,14 @@ type activeSession struct {
total int
}
-func (self *activeSession) visit(paths string, f os.FileInfo, err error) error {
+func (as *activeSession) visit(paths string, f os.FileInfo, err error) error {
if err != nil {
return err
}
if f.IsDir() {
return nil
}
- self.total = self.total + 1
+ as.total = as.total + 1
return nil
}
diff --git a/session/sess_mem.go b/session/sess_mem.go
index dd066703..dd61ef57 100644
--- a/session/sess_mem.go
+++ b/session/sess_mem.go
@@ -23,7 +23,7 @@ import (
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
-// memory session store.
+// MemSessionStore memory session store.
// it saved sessions in a map in memory.
type MemSessionStore struct {
sid string //session id
@@ -32,7 +32,7 @@ type MemSessionStore struct {
lock sync.RWMutex
}
-// set value to memory session
+// Set value to memory session
func (st *MemSessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
@@ -40,18 +40,17 @@ func (st *MemSessionStore) Set(key, value interface{}) error {
return nil
}
-// get value from memory session by key
+// Get value from memory session by key
func (st *MemSessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.value[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete in memory session by key
+// Delete in memory session by key
func (st *MemSessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
@@ -59,7 +58,7 @@ func (st *MemSessionStore) Delete(key interface{}) error {
return nil
}
-// clear all values in memory session
+// Flush clear all values in memory session
func (st *MemSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
@@ -67,15 +66,16 @@ func (st *MemSessionStore) Flush() error {
return nil
}
-// get this id of memory session store
+// SessionID get this id of memory session store
func (st *MemSessionStore) SessionID() string {
return st.sid
}
-// Implement method, no used.
+// SessionRelease Implement method, no used.
func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
}
+// MemProvider Implement the provider interface
type MemProvider struct {
lock sync.RWMutex // locker
sessions map[string]*list.Element // map in memory
@@ -84,44 +84,42 @@ type MemProvider struct {
savePath string
}
-// init memory session
+// SessionInit init memory session
func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
pder.maxlifetime = maxlifetime
pder.savePath = savePath
return nil
}
-// get memory session store by sid
-func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
+// SessionRead get memory session store by sid
+func (pder *MemProvider) SessionRead(sid string) (Store, error) {
pder.lock.RLock()
if element, ok := pder.sessions[sid]; ok {
go pder.SessionUpdate(sid)
pder.lock.RUnlock()
return element.Value.(*MemSessionStore), nil
- } else {
- pder.lock.RUnlock()
- pder.lock.Lock()
- newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
- element := pder.list.PushBack(newsess)
- pder.sessions[sid] = element
- pder.lock.Unlock()
- return newsess, nil
}
+ pder.lock.RUnlock()
+ pder.lock.Lock()
+ newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
+ element := pder.list.PushBack(newsess)
+ pder.sessions[sid] = element
+ pder.lock.Unlock()
+ return newsess, nil
}
-// check session store exist in memory session by sid
+// SessionExist check session store exist in memory session by sid
func (pder *MemProvider) SessionExist(sid string) bool {
pder.lock.RLock()
defer pder.lock.RUnlock()
if _, ok := pder.sessions[sid]; ok {
return true
- } else {
- return false
}
+ return false
}
-// generate new sid for session store in memory session
-func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
+// SessionRegenerate generate new sid for session store in memory session
+func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
pder.lock.RLock()
if element, ok := pder.sessions[oldsid]; ok {
go pder.SessionUpdate(oldsid)
@@ -132,18 +130,17 @@ func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, er
delete(pder.sessions, oldsid)
pder.lock.Unlock()
return element.Value.(*MemSessionStore), nil
- } else {
- pder.lock.RUnlock()
- pder.lock.Lock()
- newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
- element := pder.list.PushBack(newsess)
- pder.sessions[sid] = element
- pder.lock.Unlock()
- return newsess, nil
}
+ pder.lock.RUnlock()
+ pder.lock.Lock()
+ newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
+ element := pder.list.PushBack(newsess)
+ pder.sessions[sid] = element
+ pder.lock.Unlock()
+ return newsess, nil
}
-// delete session store in memory session by id
+// SessionDestroy delete session store in memory session by id
func (pder *MemProvider) SessionDestroy(sid string) error {
pder.lock.Lock()
defer pder.lock.Unlock()
@@ -155,7 +152,7 @@ func (pder *MemProvider) SessionDestroy(sid string) error {
return nil
}
-// clean expired session stores in memory session
+// SessionGC clean expired session stores in memory session
func (pder *MemProvider) SessionGC() {
pder.lock.RLock()
for {
@@ -177,12 +174,12 @@ func (pder *MemProvider) SessionGC() {
pder.lock.RUnlock()
}
-// get count number of memory session
+// SessionAll get count number of memory session
func (pder *MemProvider) SessionAll() int {
return pder.list.Len()
}
-// expand time of session store by id in memory session
+// SessionUpdate expand time of session store by id in memory session
func (pder *MemProvider) SessionUpdate(sid string) error {
pder.lock.Lock()
defer pder.lock.Unlock()
diff --git a/session/sess_utils.go b/session/sess_utils.go
index 9ae74528..d7db5ba8 100644
--- a/session/sess_utils.go
+++ b/session/sess_utils.go
@@ -43,6 +43,7 @@ func init() {
gob.Register(map[int]int64{})
}
+// EncodeGob encode the obj to gob
func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) {
for _, v := range obj {
gob.Register(v)
@@ -56,6 +57,7 @@ func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) {
return buf.Bytes(), nil
}
+// DecodeGob decode data to map
func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
@@ -178,11 +180,11 @@ func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime
return nil, err
}
// 5. DecodeGob.
- if dst, err := DecodeGob(b); err != nil {
+ dst, err := DecodeGob(b)
+ if err != nil {
return nil, err
- } else {
- return dst, nil
}
+ return dst, nil
}
// Encoding -------------------------------------------------------------------
diff --git a/session/session.go b/session/session.go
index f0895de1..39d475fc 100644
--- a/session/session.go
+++ b/session/session.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package session provider
+// Package session provider
//
// Usage:
// import(
@@ -20,7 +20,7 @@
// )
//
// func init() {
-// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "sessionIDHashFunc": "sha1", "sessionIDHashKey": "", "cookieLifeTime": 3600, "providerConfig": ""}`)
+// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`)
// go globalSessions.GC()
// }
//
@@ -37,8 +37,8 @@ import (
"time"
)
-// SessionStore contains all data for one session process with specific id.
-type SessionStore interface {
+// Store contains all data for one session process with specific id.
+type Store interface {
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
@@ -51,9 +51,9 @@ type SessionStore interface {
// it can operate a SessionStore by its id.
type Provider interface {
SessionInit(gclifetime int64, config string) error
- SessionRead(sid string) (SessionStore, error)
+ SessionRead(sid string) (Store, error)
SessionExist(sid string) bool
- SessionRegenerate(oldsid, sid string) (SessionStore, error)
+ SessionRegenerate(oldsid, sid string) (Store, error)
SessionDestroy(sid string) error
SessionAll() int //get all active session
SessionGC()
@@ -83,7 +83,7 @@ type managerConfig struct {
CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"`
- SessionIdLength int64 `json:"sessionIdLength"`
+ SessionIDLength int64 `json:"sessionIDLength"`
}
// Manager contains Provider and its configuration.
@@ -92,7 +92,7 @@ type Manager struct {
config *managerConfig
}
-// Create new Manager with provider name and json config string.
+// NewManager Create new Manager with provider name and json config string.
// provider name:
// 1. cookie
// 2. file
@@ -123,8 +123,8 @@ func NewManager(provideName, config string) (*Manager, error) {
return nil, err
}
- if cf.SessionIdLength == 0 {
- cf.SessionIdLength = 16
+ if cf.SessionIDLength == 0 {
+ cf.SessionIDLength = 16
}
return &Manager{
@@ -133,104 +133,108 @@ func NewManager(provideName, config string) (*Manager, error) {
}, nil
}
-// Start session. generate or read the session id from http request.
-// if session id exists, return SessionStore with this id.
-func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore, err error) {
+// getSid retrieves session identifier from HTTP Request.
+// First try to retrieve id by reading from cookie, session cookie name is configurable,
+// if not exist, then retrieve id from querying parameters.
+//
+// error is not nil when there is anything wrong.
+// sid is empty when need to generate a new session id
+// otherwise return an valid session id.
+func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName)
- if errs != nil || cookie.Value == "" {
- sid, errs := manager.sessionId(r)
+ if errs != nil || cookie.Value == "" || cookie.MaxAge < 0 {
+ errs := r.ParseForm()
if errs != nil {
- return nil, errs
- }
- session, err = manager.provider.SessionRead(sid)
- cookie = &http.Cookie{
- Name: manager.config.CookieName,
- Value: url.QueryEscape(sid),
- Path: "/",
- HttpOnly: true,
- Secure: manager.isSecure(r),
- Domain: manager.config.Domain,
- }
- if manager.config.CookieLifeTime > 0 {
- cookie.MaxAge = manager.config.CookieLifeTime
- cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
- }
- if manager.config.EnableSetCookie {
- http.SetCookie(w, cookie)
- }
- r.AddCookie(cookie)
- } else {
- sid, errs := url.QueryUnescape(cookie.Value)
- if errs != nil {
- return nil, errs
- }
- if manager.provider.SessionExist(sid) {
- session, err = manager.provider.SessionRead(sid)
- } else {
- sid, err = manager.sessionId(r)
- if err != nil {
- return nil, err
- }
- session, err = manager.provider.SessionRead(sid)
- cookie = &http.Cookie{
- Name: manager.config.CookieName,
- Value: url.QueryEscape(sid),
- Path: "/",
- HttpOnly: true,
- Secure: manager.isSecure(r),
- Domain: manager.config.Domain,
- }
- if manager.config.CookieLifeTime > 0 {
- cookie.MaxAge = manager.config.CookieLifeTime
- cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
- }
- if manager.config.EnableSetCookie {
- http.SetCookie(w, cookie)
- }
- r.AddCookie(cookie)
+ return "", errs
}
+
+ sid := r.FormValue(manager.config.CookieName)
+ return sid, nil
}
+
+ // HTTP Request contains cookie for sessionid info.
+ return url.QueryUnescape(cookie.Value)
+}
+
+// SessionStart generate or read the session id from http request.
+// if session id exists, return SessionStore with this id.
+func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) {
+ sid, errs := manager.getSid(r)
+ if errs != nil {
+ return nil, errs
+ }
+
+ if sid != "" && manager.provider.SessionExist(sid) {
+ return manager.provider.SessionRead(sid)
+ }
+
+ // Generate a new session
+ sid, errs = manager.sessionID()
+ if errs != nil {
+ return nil, errs
+ }
+
+ session, err = manager.provider.SessionRead(sid)
+ cookie := &http.Cookie{
+ Name: manager.config.CookieName,
+ Value: url.QueryEscape(sid),
+ Path: "/",
+ HttpOnly: true,
+ Secure: manager.isSecure(r),
+ Domain: manager.config.Domain,
+ }
+ if manager.config.CookieLifeTime > 0 {
+ cookie.MaxAge = manager.config.CookieLifeTime
+ cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
+ }
+ if manager.config.EnableSetCookie {
+ http.SetCookie(w, cookie)
+ }
+ r.AddCookie(cookie)
+
return
}
-// Destroy session by its id in http request cookie.
+// SessionDestroy Destroy session by its id in http request cookie.
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
return
- } else {
- manager.provider.SessionDestroy(cookie.Value)
+ }
+ manager.provider.SessionDestroy(cookie.Value)
+ if manager.config.EnableSetCookie {
expiration := time.Now()
- cookie := http.Cookie{Name: manager.config.CookieName,
+ cookie = &http.Cookie{Name: manager.config.CookieName,
Path: "/",
HttpOnly: true,
Expires: expiration,
MaxAge: -1}
- http.SetCookie(w, &cookie)
+
+ http.SetCookie(w, cookie)
}
}
-// Get SessionStore by its id.
-func (manager *Manager) GetSessionStore(sid string) (sessions SessionStore, err error) {
+// GetSessionStore Get SessionStore by its id.
+func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) {
sessions, err = manager.provider.SessionRead(sid)
return
}
-// Start session gc process.
+// GC Start session gc process.
// it can do gc in times after gc lifetime.
func (manager *Manager) GC() {
manager.provider.SessionGC()
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
}
-// Regenerate a session id for this SessionStore who's id is saving in http request.
-func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
- sid, err := manager.sessionId(r)
+// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request.
+func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) {
+ sid, err := manager.sessionID()
if err != nil {
return
}
cookie, err := r.Cookie(manager.config.CookieName)
- if err != nil && cookie.Value == "" {
+ if err != nil || cookie.Value == "" {
//delete old cookie
session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.config.CookieName,
@@ -251,23 +255,25 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
cookie.MaxAge = manager.config.CookieLifeTime
cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
}
- http.SetCookie(w, cookie)
+ if manager.config.EnableSetCookie {
+ http.SetCookie(w, cookie)
+ }
r.AddCookie(cookie)
return
}
-// Get all active sessions count number.
+// GetActiveSession Get all active sessions count number.
func (manager *Manager) GetActiveSession() int {
return manager.provider.SessionAll()
}
-// Set cookie with https.
+// SetSecure Set cookie with https.
func (manager *Manager) SetSecure(secure bool) {
manager.config.Secure = secure
}
-func (manager *Manager) sessionId(r *http.Request) (string, error) {
- b := make([]byte, manager.config.SessionIdLength)
+func (manager *Manager) sessionID() (string, error) {
+ b := make([]byte, manager.config.SessionIDLength)
n, err := rand.Read(b)
if n != len(b) || err != nil {
return "", fmt.Errorf("Could not successfully read from the system CSPRNG.")
diff --git a/staticfile.go b/staticfile.go
index 7c1ed98c..f9f3dc3e 100644
--- a/staticfile.go
+++ b/staticfile.go
@@ -15,107 +15,183 @@
package beego
import (
+ "bytes"
"net/http"
"os"
"path"
+ "path/filepath"
"strconv"
"strings"
+ "sync"
+
+ "errors"
+
+ "time"
"github.com/astaxie/beego/context"
- "github.com/astaxie/beego/utils"
)
+var notStaticRequestErr = errors.New("request not a static file request")
+
func serverStaticRouter(ctx *context.Context) {
if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" {
return
}
- requestPath := path.Clean(ctx.Input.Request.URL.Path)
- i := 0
- for prefix, staticDir := range StaticDir {
+
+ forbidden, filePath, fileInfo, err := lookupFile(ctx)
+ if err == notStaticRequestErr {
+ return
+ }
+
+ if forbidden {
+ exception("403", ctx)
+ return
+ }
+
+ if filePath == "" || fileInfo == nil {
+ if BConfig.RunMode == DEV {
+ Warn("Can't find/open the file:", filePath, err)
+ }
+ http.NotFound(ctx.ResponseWriter, ctx.Request)
+ return
+ }
+ if fileInfo.IsDir() {
+ //serveFile will list dir
+ http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
+ return
+ }
+
+ var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath)
+ var acceptEncoding string
+ if enableCompress {
+ acceptEncoding = context.ParseEncoding(ctx.Request)
+ }
+ b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding)
+ if err != nil {
+ if BConfig.RunMode == DEV {
+ Warn("Can't compress the file:", filePath, err)
+ }
+ http.NotFound(ctx.ResponseWriter, ctx.Request)
+ return
+ }
+
+ if b {
+ ctx.Output.Header("Content-Encoding", n)
+ } else {
+ ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10))
+ }
+
+ http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, sch)
+ return
+
+}
+
+type serveContentHolder struct {
+ *bytes.Reader
+ modTime time.Time
+ size int64
+ encoding string
+}
+
+var (
+ staticFileMap = make(map[string]*serveContentHolder)
+ mapLock sync.Mutex
+)
+
+func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, error) {
+ mapKey := acceptEncoding + ":" + filePath
+ mapFile, _ := staticFileMap[mapKey]
+ if isOk(mapFile, fi) {
+ return mapFile.encoding != "", mapFile.encoding, mapFile, nil
+ }
+ mapLock.Lock()
+ defer mapLock.Unlock()
+ if mapFile, _ = staticFileMap[mapKey]; !isOk(mapFile, fi) {
+ file, err := os.Open(filePath)
+ if err != nil {
+ return false, "", nil, err
+ }
+ defer file.Close()
+ var bufferWriter bytes.Buffer
+ _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file)
+ if err != nil {
+ return false, "", nil, err
+ }
+ mapFile = &serveContentHolder{Reader: bytes.NewReader(bufferWriter.Bytes()), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n}
+ staticFileMap[mapKey] = mapFile
+ }
+
+ return mapFile.encoding != "", mapFile.encoding, mapFile, nil
+}
+
+func isOk(s *serveContentHolder, fi os.FileInfo) bool {
+ if s == nil {
+ return false
+ }
+ return s.modTime == fi.ModTime() && s.size == fi.Size()
+}
+
+// isStaticCompress detect static files
+func isStaticCompress(filePath string) bool {
+ for _, statExtension := range BConfig.WebConfig.StaticExtensionsToGzip {
+ if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) {
+ return true
+ }
+ }
+ return false
+}
+
+// searchFile search the file by url path
+// if none the static file prefix matches ,return notStaticRequestErr
+func searchFile(ctx *context.Context) (string, os.FileInfo, error) {
+ requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path))
+ // special processing : favicon.ico/robots.txt can be in any static dir
+ if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
+ file := path.Join(".", requestPath)
+ if fi, _ := os.Stat(file); fi != nil {
+ return file, fi, nil
+ }
+ for _, staticDir := range BConfig.WebConfig.StaticDir {
+ filePath := path.Join(staticDir, requestPath)
+ if fi, _ := os.Stat(filePath); fi != nil {
+ return filePath, fi, nil
+ }
+ }
+ return "", nil, errors.New(requestPath + " file not find")
+ }
+
+ for prefix, staticDir := range BConfig.WebConfig.StaticDir {
if len(prefix) == 0 {
continue
}
- if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
- file := path.Join(staticDir, requestPath)
- if utils.FileExists(file) {
- http.ServeFile(ctx.ResponseWriter, ctx.Request, file)
- return
- } else {
- i++
- if i == len(StaticDir) {
- http.NotFound(ctx.ResponseWriter, ctx.Request)
- return
- } else {
- continue
- }
- }
+ if !strings.Contains(requestPath, prefix) {
+ continue
}
- if strings.HasPrefix(requestPath, prefix) {
- if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' {
- continue
- }
- file := path.Join(staticDir, requestPath[len(prefix):])
- finfo, err := os.Stat(file)
- if err != nil {
- if RunMode == "dev" {
- Warn("Can't find the file:", file, err)
- }
- http.NotFound(ctx.ResponseWriter, ctx.Request)
- return
- }
- //if the request is dir and DirectoryIndex is false then
- if finfo.IsDir() {
- if !DirectoryIndex {
- exception("403", ctx)
- return
- } else if ctx.Input.Request.URL.Path[len(ctx.Input.Request.URL.Path)-1] != '/' {
- http.Redirect(ctx.ResponseWriter, ctx.Request, ctx.Input.Request.URL.Path+"/", 302)
- return
- }
- } else if strings.HasSuffix(requestPath, "/index.html") {
- file := path.Join(staticDir, requestPath)
- if utils.FileExists(file) {
- http.ServeFile(ctx.ResponseWriter, ctx.Request, file)
- return
- }
- }
-
- //This block obtained from (https://github.com/smithfox/beego) - it should probably get merged into astaxie/beego after a pull request
- isStaticFileToCompress := false
- if StaticExtensionsToGzip != nil && len(StaticExtensionsToGzip) > 0 {
- for _, statExtension := range StaticExtensionsToGzip {
- if strings.HasSuffix(strings.ToLower(file), strings.ToLower(statExtension)) {
- isStaticFileToCompress = true
- break
- }
- }
- }
-
- if isStaticFileToCompress {
- var contentEncoding string
- if EnableGzip {
- contentEncoding = getAcceptEncodingZip(ctx.Request)
- }
-
- memzipfile, err := openMemZipFile(file, contentEncoding)
- if err != nil {
- return
- }
-
- if contentEncoding == "gzip" {
- ctx.Output.Header("Content-Encoding", "gzip")
- } else if contentEncoding == "deflate" {
- ctx.Output.Header("Content-Encoding", "deflate")
- } else {
- ctx.Output.Header("Content-Length", strconv.FormatInt(finfo.Size(), 10))
- }
-
- http.ServeContent(ctx.ResponseWriter, ctx.Request, file, finfo.ModTime(), memzipfile)
-
- } else {
- http.ServeFile(ctx.ResponseWriter, ctx.Request, file)
- }
- return
+ if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' {
+ continue
+ }
+ filePath := path.Join(staticDir, requestPath[len(prefix):])
+ if fi, err := os.Stat(filePath); fi != nil {
+ return filePath, fi, err
}
}
+ return "", nil, notStaticRequestErr
+}
+
+// lookupFile find the file to serve
+// if the file is dir ,search the index.html as default file( MUST NOT A DIR also)
+// if the index.html not exist or is a dir, give a forbidden response depending on DirectoryIndex
+func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) {
+ fp, fi, err := searchFile(ctx)
+ if fp == "" || fi == nil {
+ return false, "", nil, err
+ }
+ if !fi.IsDir() {
+ return false, fp, fi, err
+ }
+ ifp := filepath.Join(fp, "index.html")
+ if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() {
+ return false, ifp, ifi, err
+ }
+ return !BConfig.WebConfig.DirectoryIndex, fp, fi, err
}
diff --git a/staticfile_test.go b/staticfile_test.go
new file mode 100644
index 00000000..d3333570
--- /dev/null
+++ b/staticfile_test.go
@@ -0,0 +1,71 @@
+package beego
+
+import (
+ "bytes"
+ "compress/gzip"
+ "compress/zlib"
+ "io"
+ "io/ioutil"
+ "os"
+ "testing"
+)
+
+const licenseFile = "./LICENSE"
+
+func testOpenFile(encoding string, content []byte, t *testing.T) {
+ fi, _ := os.Stat(licenseFile)
+ b, n, sch, err := openFile(licenseFile, fi, encoding)
+ if err != nil {
+ t.Log(err)
+ t.Fail()
+ }
+
+ t.Log("open static file encoding "+n, b)
+
+ assetOpenFileAndContent(sch, content, t)
+}
+func TestOpenStaticFile_1(t *testing.T) {
+ file, _ := os.Open(licenseFile)
+ content, _ := ioutil.ReadAll(file)
+ testOpenFile("", content, t)
+}
+
+func TestOpenStaticFileGzip_1(t *testing.T) {
+ file, _ := os.Open(licenseFile)
+ var zipBuf bytes.Buffer
+ fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression)
+ io.Copy(fileWriter, file)
+ fileWriter.Close()
+ content, _ := ioutil.ReadAll(&zipBuf)
+
+ testOpenFile("gzip", content, t)
+}
+func TestOpenStaticFileDeflate_1(t *testing.T) {
+ file, _ := os.Open(licenseFile)
+ var zipBuf bytes.Buffer
+ fileWriter, _ := zlib.NewWriterLevel(&zipBuf, zlib.BestCompression)
+ io.Copy(fileWriter, file)
+ fileWriter.Close()
+ content, _ := ioutil.ReadAll(&zipBuf)
+
+ testOpenFile("deflate", content, t)
+}
+
+func assetOpenFileAndContent(sch *serveContentHolder, content []byte, t *testing.T) {
+ t.Log(sch.size, len(content))
+ if sch.size != int64(len(content)) {
+ t.Log("static content file size not same")
+ t.Fail()
+ }
+ bs, _ := ioutil.ReadAll(sch)
+ for i, v := range content {
+ if v != bs[i] {
+ t.Log("content not same")
+ t.Fail()
+ }
+ }
+ if len(staticFileMap) == 0 {
+ t.Log("men map is empty")
+ t.Fail()
+ }
+}
diff --git a/swagger/docsSpec.go b/swagger/docs_spec.go
similarity index 72%
rename from swagger/docsSpec.go
rename to swagger/docs_spec.go
index 6f8fe1db..d8402aa5 100644
--- a/swagger/docsSpec.go
+++ b/swagger/docs_spec.go
@@ -12,53 +12,59 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// swagger struct definition
+// Package swagger struct definition
package swagger
+// SwaggerVersion show the current swagger version
const SwaggerVersion = "1.2"
+// ResourceListing list the resource
type ResourceListing struct {
- ApiVersion string `json:"apiVersion"`
+ APIVersion string `json:"apiVersion"`
SwaggerVersion string `json:"swaggerVersion"` // e.g 1.2
// BasePath string `json:"basePath"` obsolete in 1.1
- Apis []ApiRef `json:"apis"`
- Infos Infomation `json:"info"`
+ APIs []APIRef `json:"apis"`
+ Info Information `json:"info"`
}
-type ApiRef struct {
+// APIRef description the api path and description
+type APIRef struct {
Path string `json:"path"` // relative or absolute, must start with /
Description string `json:"description"`
}
-type Infomation struct {
+// Information show the API Information
+type Information struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Contact string `json:"contact,omitempty"`
- TermsOfServiceUrl string `json:"termsOfServiceUrl,omitempty"`
+ TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"`
License string `json:"license,omitempty"`
- LicenseUrl string `json:"licenseUrl,omitempty"`
+ LicenseURL string `json:"licenseUrl,omitempty"`
}
-// https://github.com/wordnik/swagger-core/blob/scala_2.10-1.3-RC3/schemas/api-declaration-schema.json
-type ApiDeclaration struct {
- ApiVersion string `json:"apiVersion"`
+// APIDeclaration see https://github.com/wordnik/swagger-core/blob/scala_2.10-1.3-RC3/schemas/api-declaration-schema.json
+type APIDeclaration struct {
+ APIVersion string `json:"apiVersion"`
SwaggerVersion string `json:"swaggerVersion"`
BasePath string `json:"basePath"`
ResourcePath string `json:"resourcePath"` // must start with /
Consumes []string `json:"consumes,omitempty"`
Produces []string `json:"produces,omitempty"`
- Apis []Api `json:"apis,omitempty"`
+ APIs []API `json:"apis,omitempty"`
Models map[string]Model `json:"models,omitempty"`
}
-type Api struct {
+// API show tha API struct
+type API struct {
Path string `json:"path"` // relative or absolute, must start with /
Description string `json:"description"`
Operations []Operation `json:"operations,omitempty"`
}
+// Operation desc the Operation
type Operation struct {
- HttpMethod string `json:"httpMethod"`
+ HTTPMethod string `json:"httpMethod"`
Nickname string `json:"nickname"`
Type string `json:"type"` // in 1.1 = DataType
// ResponseClass string `json:"responseClass"` obsolete in 1.2
@@ -72,15 +78,18 @@ type Operation struct {
Protocols []Protocol `json:"protocols,omitempty"`
}
+// Protocol support which Protocol
type Protocol struct {
}
+// ResponseMessage Show the
type ResponseMessage struct {
Code int `json:"code"`
Message string `json:"message"`
ResponseModel string `json:"responseModel"`
}
+// Parameter desc the request parameters
type Parameter struct {
ParamType string `json:"paramType"` // path,query,body,header,form
Name string `json:"name"`
@@ -94,17 +103,20 @@ type Parameter struct {
Maximum int `json:"maximum"`
}
+// ErrorResponse desc response
type ErrorResponse struct {
Code int `json:"code"`
Reason string `json:"reason"`
}
+// Model define the data model
type Model struct {
- Id string `json:"id"`
+ ID string `json:"id"`
Required []string `json:"required,omitempty"`
Properties map[string]ModelProperty `json:"properties"`
}
+// ModelProperty define the properties
type ModelProperty struct {
Type string `json:"type"`
Description string `json:"description"`
@@ -112,20 +124,20 @@ type ModelProperty struct {
Format string `json:"format"`
}
-// https://github.com/wordnik/swagger-core/wiki/authorizations
+// Authorization see https://github.com/wordnik/swagger-core/wiki/authorizations
type Authorization struct {
LocalOAuth OAuth `json:"local-oauth"`
- ApiKey ApiKey `json:"apiKey"`
+ APIKey APIKey `json:"apiKey"`
}
-// https://github.com/wordnik/swagger-core/wiki/authorizations
+// OAuth see https://github.com/wordnik/swagger-core/wiki/authorizations
type OAuth struct {
Type string `json:"type"` // e.g. oauth2
Scopes []string `json:"scopes"` // e.g. PUBLIC
GrantTypes map[string]GrantType `json:"grantTypes"`
}
-// https://github.com/wordnik/swagger-core/wiki/authorizations
+// GrantType see https://github.com/wordnik/swagger-core/wiki/authorizations
type GrantType struct {
LoginEndpoint Endpoint `json:"loginEndpoint"`
TokenName string `json:"tokenName"` // e.g. access_code
@@ -133,16 +145,16 @@ type GrantType struct {
TokenEndpoint Endpoint `json:"tokenEndpoint"`
}
-// https://github.com/wordnik/swagger-core/wiki/authorizations
+// Endpoint see https://github.com/wordnik/swagger-core/wiki/authorizations
type Endpoint struct {
- Url string `json:"url"`
- ClientIdName string `json:"clientIdName"`
+ URL string `json:"url"`
+ ClientIDName string `json:"clientIdName"`
ClientSecretName string `json:"clientSecretName"`
TokenName string `json:"tokenName"`
}
-// https://github.com/wordnik/swagger-core/wiki/authorizations
-type ApiKey struct {
+// APIKey see https://github.com/wordnik/swagger-core/wiki/authorizations
+type APIKey struct {
Type string `json:"type"` // e.g. apiKey
PassAs string `json:"passAs"` // e.g. header
}
diff --git a/template.go b/template.go
index 64b1939e..9aac3ea2 100644
--- a/template.go
+++ b/template.go
@@ -28,17 +28,14 @@ import (
)
var (
- beegoTplFuncMap template.FuncMap
- // beego template caching map and supported template file extensions.
- BeeTemplates map[string]*template.Template
- BeeTemplateExt []string
+ beegoTplFuncMap = make(template.FuncMap)
+ // BeeTemplates caching map and supported template file extensions.
+ BeeTemplates = make(map[string]*template.Template)
+ // BeeTemplateExt stores the template extention which will build
+ BeeTemplateExt = []string{"tpl", "html"}
)
func init() {
- BeeTemplates = make(map[string]*template.Template)
- beegoTplFuncMap = make(template.FuncMap)
- BeeTemplateExt = make([]string, 0)
- BeeTemplateExt = append(BeeTemplateExt, "tpl", "html")
beegoTplFuncMap["dateformat"] = DateFormat
beegoTplFuncMap["date"] = Date
beegoTplFuncMap["compare"] = Compare
@@ -46,14 +43,15 @@ func init() {
beegoTplFuncMap["not_nil"] = NotNil
beegoTplFuncMap["not_null"] = NotNil
beegoTplFuncMap["substr"] = Substr
- beegoTplFuncMap["html2str"] = Html2str
+ beegoTplFuncMap["html2str"] = HTML2str
beegoTplFuncMap["str2html"] = Str2html
beegoTplFuncMap["htmlquote"] = Htmlquote
beegoTplFuncMap["htmlunquote"] = Htmlunquote
beegoTplFuncMap["renderform"] = RenderForm
beegoTplFuncMap["assets_js"] = AssetsJs
- beegoTplFuncMap["assets_css"] = AssetsCss
+ beegoTplFuncMap["assets_css"] = AssetsCSS
beegoTplFuncMap["config"] = Config
+ beegoTplFuncMap["map_get"] = MapGet
// go1.2 added template funcs
// Comparisons
@@ -64,7 +62,7 @@ func init() {
beegoTplFuncMap["lt"] = lt // <
beegoTplFuncMap["ne"] = ne // !=
- beegoTplFuncMap["urlfor"] = UrlFor // !=
+ beegoTplFuncMap["urlfor"] = URLFor // !=
}
// AddFuncMap let user to register a func in the template.
@@ -78,7 +76,7 @@ type templatefile struct {
files map[string][]string
}
-func (self *templatefile) visit(paths string, f os.FileInfo, err error) error {
+func (tf *templatefile) visit(paths string, f os.FileInfo, err error) error {
if f == nil {
return err
}
@@ -91,21 +89,21 @@ func (self *templatefile) visit(paths string, f os.FileInfo, err error) error {
replace := strings.NewReplacer("\\", "/")
a := []byte(paths)
- a = a[len([]byte(self.root)):]
+ a = a[len([]byte(tf.root)):]
file := strings.TrimLeft(replace.Replace(string(a)), "/")
subdir := filepath.Dir(file)
- if _, ok := self.files[subdir]; ok {
- self.files[subdir] = append(self.files[subdir], file)
+ if _, ok := tf.files[subdir]; ok {
+ tf.files[subdir] = append(tf.files[subdir], file)
} else {
m := make([]string, 1)
m[0] = file
- self.files[subdir] = m
+ tf.files[subdir] = m
}
return nil
}
-// return this path contains supported template extension of beego or not.
+// HasTemplateExt return this path contains supported template extension of beego or not.
func HasTemplateExt(paths string) bool {
for _, v := range BeeTemplateExt {
if strings.HasSuffix(paths, "."+v) {
@@ -115,7 +113,7 @@ func HasTemplateExt(paths string) bool {
return false
}
-// add new extension for template.
+// AddTemplateExt add new extension for template.
func AddTemplateExt(ext string) {
for _, v := range BeeTemplateExt {
if v == ext {
@@ -125,15 +123,14 @@ func AddTemplateExt(ext string) {
BeeTemplateExt = append(BeeTemplateExt, ext)
}
-// build all template files in a directory.
+// BuildTemplate will build all template files in a directory.
// it makes beego can render any template file in view directory.
-func BuildTemplate(dir string) error {
+func BuildTemplate(dir string, files ...string) error {
if _, err := os.Stat(dir); err != nil {
if os.IsNotExist(err) {
return nil
- } else {
- return errors.New("dir open err")
}
+ return errors.New("dir open err")
}
self := &templatefile{
root: dir,
@@ -148,11 +145,13 @@ func BuildTemplate(dir string) error {
}
for _, v := range self.files {
for _, file := range v {
- t, err := getTemplate(self.root, file, v...)
- if err != nil {
- Trace("parse template err:", file, err)
- } else {
- BeeTemplates[file] = t
+ if len(files) == 0 || utils.InSlice(file, files) {
+ t, err := getTemplate(self.root, file, v...)
+ if err != nil {
+ Trace("parse template err:", file, err)
+ } else {
+ BeeTemplates[file] = t
+ }
}
}
}
@@ -177,7 +176,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
if err != nil {
return nil, [][]string{}, err
}
- reg := regexp.MustCompile(TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"")
+ reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"")
allsub := reg.FindAllStringSubmatch(string(data), -1)
for _, m := range allsub {
if len(m) == 2 {
@@ -198,7 +197,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
}
func getTemplate(root, file string, others ...string) (t *template.Template, err error) {
- t = template.New(file).Delims(TemplateLeft, TemplateRight).Funcs(beegoTplFuncMap)
+ t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap)
var submods [][]string
t, submods, err = getTplDeep(root, file, "", t)
if err != nil {
@@ -240,7 +239,7 @@ func _getTemplate(t0 *template.Template, root string, submods [][]string, others
if err != nil {
continue
}
- reg := regexp.MustCompile(TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"")
+ reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"")
allsub := reg.FindAllStringSubmatch(string(data), -1)
for _, sub := range allsub {
if len(sub) == 2 && sub[1] == m[1] {
@@ -260,3 +259,30 @@ func _getTemplate(t0 *template.Template, root string, submods [][]string, others
}
return
}
+
+// SetViewsPath sets view directory path in beego application.
+func SetViewsPath(path string) *App {
+ BConfig.WebConfig.ViewsPath = path
+ return BeeApp
+}
+
+// SetStaticPath sets static directory path and proper url pattern in beego application.
+// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public".
+func SetStaticPath(url string, path string) *App {
+ if !strings.HasPrefix(url, "/") {
+ url = "/" + url
+ }
+ url = strings.TrimRight(url, "/")
+ BConfig.WebConfig.StaticDir[url] = path
+ return BeeApp
+}
+
+// DelStaticPath removes the static folder setting in this url pattern in beego application.
+func DelStaticPath(url string) *App {
+ if !strings.HasPrefix(url, "/") {
+ url = "/" + url
+ }
+ url = strings.TrimRight(url, "/")
+ delete(BConfig.WebConfig.StaticDir, url)
+ return BeeApp
+}
diff --git a/template_test.go b/template_test.go
index b35da5ce..2e222efc 100644
--- a/template_test.go
+++ b/template_test.go
@@ -20,11 +20,11 @@ import (
"testing"
)
-var header string = `{{define "header"}}
+var header = `{{define "header"}}
Hello, astaxie!
{{end}}`
-var index string = `
+var index = `
beego welcome template
@@ -37,7 +37,7 @@ var index string = `
`
-var block string = `{{define "block"}}
+var block = `{{define "block"}}