diff --git a/.gitignore b/.gitignore
index 39ae5706..9806457b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@
.DS_Store
*.swp
*.swo
+beego.iml
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 00000000..c59cef61
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,30 @@
+language: go
+
+go:
+ - 1.5.1
+
+services:
+ - redis-server
+ - mysql
+ - postgresql
+ - memcached
+env:
+ - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
+ - ORM_DRIVER=mysql ORM_SOURCE="root:@/orm_test?charset=utf8"
+ - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
+install:
+ - go get github.com/lib/pq
+ - go get github.com/go-sql-driver/mysql
+ - go get github.com/mattn/go-sqlite3
+ - go get github.com/bradfitz/gomemcache/memcache
+ - go get github.com/garyburd/redigo/redis
+ - go get github.com/beego/x2j
+ - go get github.com/beego/goyaml2
+ - go get github.com/belogik/goes
+ - go get github.com/couchbase/go-couchbase
+ - go get github.com/siddontang/ledisdb/config
+ - go get github.com/siddontang/ledisdb/ledis
+before_script:
+ - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
+ - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi"
+ - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi"
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..9d511616
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,52 @@
+# Contributing to beego
+
+beego is an open source project.
+
+It is the work of hundreds of contributors. We appreciate your help!
+
+Here are instructions to get you started. They are probably not perfect,
+please let us know if anything feels wrong or incomplete.
+
+## Contribution guidelines
+
+### Pull requests
+
+First of all. beego follow the gitflow. So please send you pull request
+to **develop** branch. We will close the pull request to master branch.
+
+We are always happy to receive pull requests, and do our best to
+review them as fast as possible. Not sure if that typo is worth a pull
+request? Do it! We will appreciate it.
+
+If your pull request is not accepted on the first try, don't be
+discouraged! Sometimes we can make a mistake, please do more explaining
+for us. We will appreciate it.
+
+We're trying very hard to keep beego simple and fast. We don't want it
+to do everything for everybody. This means that we might decide against
+incorporating a new feature. But we will give you some advice on how to
+do it in other way.
+
+### Create issues
+
+Any significant improvement should be documented as [a GitHub
+issue](https://github.com/astaxie/beego/issues) before anybody
+starts working on it.
+
+Also when filing an issue, make sure to answer these five questions:
+
+- What version of beego are you using (bee version)?
+- What operating system and processor architecture are you using?
+- What did you do?
+- What did you expect to see?
+- What did you see instead?
+
+### but check existing issues and docs first!
+
+Please take a moment to check that an issue doesn't already exist
+documenting your bug report or improvement proposal. If it does, it
+never hurts to add a quick "+1" or "I have this problem too". This will
+help prioritize the most common problems and requests.
+
+Also if you don't know how to use it. please make sure you have read though
+the docs in http://beego.me/docs
\ No newline at end of file
diff --git a/README.md b/README.md
index 7b650887..fec6113f 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,38 @@
## Beego
-[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest)
+[![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego)
[![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego)
-beego is an open-source, high-performance, modular, full-stack web framework.
+beego is used for rapid development of RESTful APIs, web apps and backend services in Go.
+It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding.
More info [beego.me](http://beego.me)
-## Installation
+##Quick Start
+######Download and install
go get github.com/astaxie/beego
+######Create file `hello.go`
+```go
+package main
+
+import "github.com/astaxie/beego"
+
+func main(){
+ beego.Run()
+}
+```
+######Build and run
+```bash
+ go build hello.go
+ ./hello
+```
+######Congratulations!
+You just built your first beego app.
+Open your browser and visit `http://localhost:8000`.
+Please see [Documentation](http://beego.me/docs) for more.
+
## Features
* RESTful support
@@ -26,6 +48,7 @@ More info [beego.me](http://beego.me)
* [English](http://beego.me/docs/intro/)
* [中文文档](http://beego.me/docs/intro/)
+* [Русский](http://beego.me/docs/intro/)
## Community
@@ -33,5 +56,5 @@ More info [beego.me](http://beego.me)
## LICENSE
-beego is licensed under the Apache Licence, Version 2.0
+beego source code is licensed under the Apache Licence, Version 2.0
(http://www.apache.org/licenses/LICENSE-2.0.html).
diff --git a/admin.go b/admin.go
index 90142c55..3effc582 100644
--- a/admin.go
+++ b/admin.go
@@ -19,9 +19,11 @@ import (
"encoding/json"
"fmt"
"net/http"
+ "os"
"text/template"
"time"
+ "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils"
)
@@ -63,24 +65,15 @@ func init() {
// AdminIndex is the default http.Handler for admin module.
// it matches url pattern "/".
func adminIndex(rw http.ResponseWriter, r *http.Request) {
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(indexTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- data := make(map[interface{}]interface{})
- tmpl.Execute(rw, data)
+ execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl)
}
// QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter.
// it's registered with url pattern "/qbs" in admin module.
func qpsIndex(rw http.ResponseWriter, r *http.Request) {
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(qpsTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
data := make(map[interface{}]interface{})
data["Content"] = toolbox.StatisticsMap.GetMap()
-
- tmpl.Execute(rw, data)
-
+ execTpl(rw, data, qpsTpl, defaultScriptsTpl)
}
// ListConf is the http.Handler of displaying all beego configuration values as key/value pair.
@@ -88,178 +81,145 @@ func qpsIndex(rw http.ResponseWriter, r *http.Request) {
func listConf(rw http.ResponseWriter, r *http.Request) {
r.ParseForm()
command := r.Form.Get("command")
- if command != "" {
- data := make(map[interface{}]interface{})
- switch command {
- case "conf":
- m := make(map[string]interface{})
+ if command == "" {
+ rw.Write([]byte("command not support"))
+ return
+ }
- m["AppName"] = AppName
- m["AppPath"] = AppPath
- m["AppConfigPath"] = AppConfigPath
- m["StaticDir"] = StaticDir
- m["StaticExtensionsToGzip"] = StaticExtensionsToGzip
- m["HttpAddr"] = HttpAddr
- m["HttpPort"] = HttpPort
- m["HttpTLS"] = EnableHttpTLS
- m["HttpCertFile"] = HttpCertFile
- m["HttpKeyFile"] = HttpKeyFile
- m["RecoverPanic"] = RecoverPanic
- m["AutoRender"] = AutoRender
- m["ViewsPath"] = ViewsPath
- m["RunMode"] = RunMode
- m["SessionOn"] = SessionOn
- m["SessionProvider"] = SessionProvider
- m["SessionName"] = SessionName
- m["SessionGCMaxLifetime"] = SessionGCMaxLifetime
- m["SessionSavePath"] = SessionSavePath
- m["SessionCookieLifeTime"] = SessionCookieLifeTime
- m["UseFcgi"] = UseFcgi
- m["MaxMemory"] = MaxMemory
- m["EnableGzip"] = EnableGzip
- m["DirectoryIndex"] = DirectoryIndex
- m["HttpServerTimeOut"] = HttpServerTimeOut
- m["ErrorsShow"] = ErrorsShow
- m["XSRFKEY"] = XSRFKEY
- m["EnableXSRF"] = EnableXSRF
- m["XSRFExpire"] = XSRFExpire
- m["CopyRequestBody"] = CopyRequestBody
- m["TemplateLeft"] = TemplateLeft
- m["TemplateRight"] = TemplateRight
- m["BeegoServerName"] = BeegoServerName
- m["EnableAdmin"] = EnableAdmin
- m["AdminHttpAddr"] = AdminHttpAddr
- m["AdminHttpPort"] = AdminHttpPort
+ data := make(map[interface{}]interface{})
+ switch command {
+ case "conf":
+ m := make(map[string]interface{})
+ m["AppConfigPath"] = AppConfigPath
+ m["AppConfigProvider"] = AppConfigProvider
+ m["BConfig.AppName"] = BConfig.AppName
+ m["BConfig.RunMode"] = BConfig.RunMode
+ m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
+ m["BConfig.ServerName"] = BConfig.ServerName
+ m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
+ m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
+ m["BConfig.EnableGzip"] = BConfig.EnableGzip
+ m["BConfig.MaxMemory"] = BConfig.MaxMemory
+ m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
+ m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
+ m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
+ m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
+ m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
+ m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
+ m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
+ m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
+ m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
+ m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
+ m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
+ m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
+ m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
+ m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
+ m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
+ m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
+ m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
+ m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
+ m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
+ m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
+ m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
+ m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
+ m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
+ m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
+ m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
+ m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
+ m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
+ m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
+ m["BConfig.WebConfig.XSRFKEY"] = BConfig.WebConfig.XSRFKey
+ m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
+ m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
+ m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
+ m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
+ m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
+ m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
+ m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
+ m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
+ m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
+ m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
+ m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
+ m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
+ tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
+ tmpl = template.Must(tmpl.Parse(configTpl))
+ tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(configTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
+ data["Content"] = m
- data["Content"] = m
+ tmpl.Execute(rw, data)
- tmpl.Execute(rw, data)
-
- case "router":
- content := make(map[string]interface{})
-
- var fields = []string{
- fmt.Sprintf("Router Pattern"),
- fmt.Sprintf("Methods"),
- fmt.Sprintf("Controller"),
+ case "router":
+ var (
+ content = map[string]interface{}{
+ "Fields": []string{
+ "Router Pattern",
+ "Methods",
+ "Controller",
+ },
}
- content["Fields"] = fields
+ methods = []string{}
+ methodsData = make(map[string]interface{})
+ )
+ for method, t := range BeeApp.Handlers.routers {
- methods := []string{}
- methodsData := make(map[string]interface{})
- for method, t := range BeeApp.Handlers.routers {
+ resultList := new([][]string)
- resultList := new([][]string)
+ printTree(resultList, t)
- printTree(resultList, t)
-
- methods = append(methods, method)
- methodsData[method] = resultList
- }
-
- content["Data"] = methodsData
- content["Methods"] = methods
- data["Content"] = content
- data["Title"] = "Routers"
-
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(routerAndFilterTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- tmpl.Execute(rw, data)
- case "filter":
- content := make(map[string]interface{})
-
- var fields = []string{
- fmt.Sprintf("Router Pattern"),
- fmt.Sprintf("Filter Function"),
- }
- content["Fields"] = fields
-
- filterTypes := []string{}
- filterTypeData := make(map[string]interface{})
-
- if BeeApp.Handlers.enableFilter {
- var filterType string
-
- if bf, ok := BeeApp.Handlers.filters[BeforeRouter]; ok {
- filterType = "Before Router"
- filterTypes = append(filterTypes, filterType)
- resultList := new([][]string)
- for _, f := range bf {
-
- var result = []string{
- fmt.Sprintf("%s", f.pattern),
- fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
- }
- *resultList = append(*resultList, result)
- }
- filterTypeData[filterType] = resultList
- }
-
- if bf, ok := BeeApp.Handlers.filters[BeforeExec]; ok {
- filterType = "Before Exec"
- filterTypes = append(filterTypes, filterType)
- resultList := new([][]string)
- for _, f := range bf {
-
- var result = []string{
- fmt.Sprintf("%s", f.pattern),
- fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
- }
- *resultList = append(*resultList, result)
- }
- filterTypeData[filterType] = resultList
- }
-
- if bf, ok := BeeApp.Handlers.filters[AfterExec]; ok {
- filterType = "After Exec"
- filterTypes = append(filterTypes, filterType)
- resultList := new([][]string)
- for _, f := range bf {
-
- var result = []string{
- fmt.Sprintf("%s", f.pattern),
- fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
- }
- *resultList = append(*resultList, result)
- }
- filterTypeData[filterType] = resultList
- }
-
- if bf, ok := BeeApp.Handlers.filters[FinishRouter]; ok {
- filterType = "Finish Router"
- filterTypes = append(filterTypes, filterType)
- resultList := new([][]string)
- for _, f := range bf {
-
- var result = []string{
- fmt.Sprintf("%s", f.pattern),
- fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
- }
- *resultList = append(*resultList, result)
- }
- filterTypeData[filterType] = resultList
- }
- }
-
- content["Data"] = filterTypeData
- content["Methods"] = filterTypes
-
- data["Content"] = content
- data["Title"] = "Filters"
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(routerAndFilterTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- tmpl.Execute(rw, data)
-
- default:
- rw.Write([]byte("command not support"))
+ methods = append(methods, method)
+ methodsData[method] = resultList
}
- } else {
+
+ content["Data"] = methodsData
+ content["Methods"] = methods
+ data["Content"] = content
+ data["Title"] = "Routers"
+ execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
+ case "filter":
+ var (
+ content = map[string]interface{}{
+ "Fields": []string{
+ "Router Pattern",
+ "Filter Function",
+ },
+ }
+ filterTypes = []string{}
+ filterTypeData = make(map[string]interface{})
+ )
+
+ if BeeApp.Handlers.enableFilter {
+ var filterType string
+ for k, fr := range map[int]string{
+ BeforeStatic: "Before Static",
+ BeforeRouter: "Before Router",
+ BeforeExec: "Before Exec",
+ AfterExec: "After Exec",
+ FinishRouter: "Finish Router"} {
+ if bf, ok := BeeApp.Handlers.filters[k]; ok {
+ filterType = fr
+ filterTypes = append(filterTypes, filterType)
+ resultList := new([][]string)
+ for _, f := range bf {
+ var result = []string{
+ fmt.Sprintf("%s", f.pattern),
+ fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
+ }
+ *resultList = append(*resultList, result)
+ }
+ filterTypeData[filterType] = resultList
+ }
+ }
+ }
+
+ content["Data"] = filterTypeData
+ content["Methods"] = filterTypes
+
+ data["Content"] = content
+ data["Title"] = "Filters"
+ execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
+ default:
+ rw.Write([]byte("command not support"))
}
}
@@ -274,23 +234,23 @@ func printTree(resultList *[][]string, t *Tree) {
if v, ok := l.runObject.(*controllerInfo); ok {
if v.routerType == routerTypeBeego {
var result = []string{
- fmt.Sprintf("%s", v.pattern),
+ v.pattern,
fmt.Sprintf("%s", v.methods),
fmt.Sprintf("%s", v.controllerType),
}
*resultList = append(*resultList, result)
} else if v.routerType == routerTypeRESTFul {
var result = []string{
- fmt.Sprintf("%s", v.pattern),
+ v.pattern,
fmt.Sprintf("%s", v.methods),
- fmt.Sprintf(""),
+ "",
}
*resultList = append(*resultList, result)
} else if v.routerType == routerTypeHandler {
var result = []string{
- fmt.Sprintf("%s", v.pattern),
- fmt.Sprintf(""),
- fmt.Sprintf(""),
+ v.pattern,
+ "",
+ "",
}
*resultList = append(*resultList, result)
}
@@ -303,54 +263,49 @@ func printTree(resultList *[][]string, t *Tree) {
func profIndex(rw http.ResponseWriter, r *http.Request) {
r.ParseForm()
command := r.Form.Get("command")
- format := r.Form.Get("format")
- data := make(map[string]interface{})
+ if command == "" {
+ return
+ }
- var result bytes.Buffer
- if command != "" {
- toolbox.ProcessInput(command, &result)
- data["Content"] = result.String()
+ var (
+ format = r.Form.Get("format")
+ data = make(map[interface{}]interface{})
+ result bytes.Buffer
+ )
+ toolbox.ProcessInput(command, &result)
+ data["Content"] = result.String()
- if format == "json" && command == "gc summary" {
- dataJson, err := json.Marshal(data)
- if err != nil {
- http.Error(rw, err.Error(), http.StatusInternalServerError)
- return
- }
-
- rw.Header().Set("Content-Type", "application/json")
- rw.Write(dataJson)
+ if format == "json" && command == "gc summary" {
+ dataJSON, err := json.Marshal(data)
+ if err != nil {
+ http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
- data["Title"] = command
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(profillingTpl))
- if command == "gc summary" {
- tmpl = template.Must(tmpl.Parse(gcAjaxTpl))
- } else {
-
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- }
- tmpl.Execute(rw, data)
- } else {
+ rw.Header().Set("Content-Type", "application/json")
+ rw.Write(dataJSON)
+ return
}
+
+ data["Title"] = command
+ defaultTpl := defaultScriptsTpl
+ if command == "gc summary" {
+ defaultTpl = gcAjaxTpl
+ }
+ execTpl(rw, data, profillingTpl, defaultTpl)
}
// Healthcheck is a http.Handler calling health checking and showing the result.
// it's in "/healthcheck" pattern in admin module.
func healthcheck(rw http.ResponseWriter, req *http.Request) {
- data := make(map[interface{}]interface{})
-
- var result = []string{}
- fields := []string{
- fmt.Sprintf("Name"),
- fmt.Sprintf("Message"),
- fmt.Sprintf("Status"),
- }
- resultList := new([][]string)
-
- content := make(map[string]interface{})
+ var (
+ data = make(map[interface{}]interface{})
+ result = []string{}
+ resultList = new([][]string)
+ content = map[string]interface{}{
+ "Fields": []string{"Name", "Message", "Status"},
+ }
+ )
for name, h := range toolbox.AdminCheckList {
if err := h.Check(); err != nil {
@@ -370,16 +325,10 @@ func healthcheck(rw http.ResponseWriter, req *http.Request) {
}
*resultList = append(*resultList, result)
}
-
- content["Fields"] = fields
content["Data"] = resultList
data["Content"] = content
data["Title"] = "Health Check"
- tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(healthCheckTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
- tmpl.Execute(rw, data)
-
+ execTpl(rw, data, healthCheckTpl, defaultScriptsTpl)
}
// TaskStatus is a http.Handler with running task status (task name, status and the last execution).
@@ -391,10 +340,8 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
req.ParseForm()
taskname := req.Form.Get("taskname")
if taskname != "" {
-
if t, ok := toolbox.AdminTaskList[taskname]; ok {
- err := t.Run()
- if err != nil {
+ if err := t.Run(); err != nil {
data["Message"] = []string{"error", fmt.Sprintf("%s", err)}
}
data["Message"] = []string{"success", fmt.Sprintf("%s run success,Now the Status is %s", taskname, t.GetStatus())}
@@ -408,18 +355,18 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
resultList := new([][]string)
var result = []string{}
var fields = []string{
- fmt.Sprintf("Task Name"),
- fmt.Sprintf("Task Spec"),
- fmt.Sprintf("Task Status"),
- fmt.Sprintf("Last Time"),
- fmt.Sprintf(""),
+ "Task Name",
+ "Task Spec",
+ "Task Status",
+ "Last Time",
+ "",
}
for tname, tk := range toolbox.AdminTaskList {
result = []string{
- fmt.Sprintf("%s", tname),
+ tname,
fmt.Sprintf("%s", tk.GetSpec()),
fmt.Sprintf("%s", tk.GetStatus()),
- fmt.Sprintf("%s", tk.GetPrev().String()),
+ tk.GetPrev().String(),
}
*resultList = append(*resultList, result)
}
@@ -428,9 +375,14 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
content["Data"] = resultList
data["Content"] = content
data["Title"] = "Tasks"
+ execTpl(rw, data, tasksTpl, defaultScriptsTpl)
+}
+
+func execTpl(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) {
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
- tmpl = template.Must(tmpl.Parse(tasksTpl))
- tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
+ for _, tpl := range tpls {
+ tmpl = template.Must(tmpl.Parse(tpl))
+ }
tmpl.Execute(rw, data)
}
@@ -450,17 +402,23 @@ func (admin *adminApp) Run() {
if len(toolbox.AdminTaskList) > 0 {
toolbox.StartTask()
}
- addr := AdminHttpAddr
+ addr := BConfig.Listen.AdminAddr
- if AdminHttpPort != 0 {
- addr = fmt.Sprintf("%s:%d", AdminHttpAddr, AdminHttpPort)
+ if BConfig.Listen.AdminPort != 0 {
+ addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort)
}
for p, f := range admin.routers {
http.Handle(p, f)
}
BeeLogger.Info("Admin server Running on %s", addr)
- err := http.ListenAndServe(addr, nil)
+
+ var err error
+ if BConfig.Listen.Graceful {
+ err = grace.ListenAndServe(addr, nil)
+ } else {
+ err = http.ListenAndServe(addr, nil)
+ }
if err != nil {
- BeeLogger.Critical("Admin ListenAndServe: ", err)
+ BeeLogger.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
}
}
diff --git a/app.go b/app.go
index 35040f33..af54ea4b 100644
--- a/app.go
+++ b/app.go
@@ -20,14 +20,26 @@ import (
"net/http"
"net/http/fcgi"
"os"
+ "path"
"time"
+ "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/utils"
)
+var (
+ // BeeApp is an application instance
+ BeeApp *App
+)
+
+func init() {
+ // create beego application
+ BeeApp = NewApp()
+}
+
// App defines beego application with a new PatternServeMux.
type App struct {
- Handlers *ControllerRegistor
+ Handlers *ControllerRegister
Server *http.Server
}
@@ -40,93 +52,311 @@ func NewApp() *App {
// Run beego application.
func (app *App) Run() {
- addr := HttpAddr
+ addr := BConfig.Listen.HTTPAddr
- if HttpPort != 0 {
- addr = fmt.Sprintf("%s:%d", HttpAddr, HttpPort)
+ if BConfig.Listen.HTTPPort != 0 {
+ addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort)
}
var (
- err error
- l net.Listener
+ err error
+ l net.Listener
+ endRunning = make(chan bool, 1)
)
- endRunning := make(chan bool, 1)
- if UseFcgi {
- if UseStdIo {
- err = fcgi.Serve(nil, app.Handlers) // standard I/O
- if err == nil {
+ // run cgi server
+ if BConfig.Listen.EnableFcgi {
+ if BConfig.Listen.EnableStdIo {
+ if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
BeeLogger.Info("Use FCGI via standard I/O")
} else {
- BeeLogger.Info("Cannot use FCGI via standard I/O", err)
+ BeeLogger.Critical("Cannot use FCGI via standard I/O", err)
}
- } else {
- if HttpPort == 0 {
- // remove the Socket file before start
- if utils.FileExists(addr) {
- os.Remove(addr)
- }
- l, err = net.Listen("unix", addr)
- } else {
- l, err = net.Listen("tcp", addr)
- }
- if err != nil {
- BeeLogger.Critical("Listen: ", err)
- }
- err = fcgi.Serve(l, app.Handlers)
+ return
}
- } else {
- app.Server.Addr = addr
- app.Server.Handler = app.Handlers
- app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second
- app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second
+ if BConfig.Listen.HTTPPort == 0 {
+ // remove the Socket file before start
+ if utils.FileExists(addr) {
+ os.Remove(addr)
+ }
+ l, err = net.Listen("unix", addr)
+ } else {
+ l, err = net.Listen("tcp", addr)
+ }
+ if err != nil {
+ BeeLogger.Critical("Listen: ", err)
+ }
+ if err = fcgi.Serve(l, app.Handlers); err != nil {
+ BeeLogger.Critical("fcgi.Serve: ", err)
+ }
+ return
+ }
- if EnableHttpTLS {
+ app.Server.Handler = app.Handlers
+ app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
+ app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
+
+ // run graceful mode
+ if BConfig.Listen.Graceful {
+ httpsAddr := BConfig.Listen.HTTPSAddr
+ app.Server.Addr = httpsAddr
+ if BConfig.Listen.EnableHTTPS {
go func() {
time.Sleep(20 * time.Microsecond)
- if HttpsPort != 0 {
- app.Server.Addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort)
+ if BConfig.Listen.HTTPSPort != 0 {
+ httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
+ app.Server.Addr = httpsAddr
}
- BeeLogger.Info("https server Running on %s", app.Server.Addr)
- err := app.Server.ListenAndServeTLS(HttpCertFile, HttpKeyFile)
- if err != nil {
- BeeLogger.Critical("ListenAndServeTLS: ", err)
+ server := grace.NewServer(httpsAddr, app.Handlers)
+ server.Server.ReadTimeout = app.Server.ReadTimeout
+ server.Server.WriteTimeout = app.Server.WriteTimeout
+ if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
+ BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
-
- if EnableHttpListen {
+ if BConfig.Listen.EnableHTTP {
go func() {
- app.Server.Addr = addr
- BeeLogger.Info("http server Running on %s", app.Server.Addr)
- if ListenTCP4 && HttpAddr == "" {
- ln, err := net.Listen("tcp4", app.Server.Addr)
- if err != nil {
- BeeLogger.Critical("ListenAndServe: ", err)
- time.Sleep(100 * time.Microsecond)
- endRunning <- true
- return
- }
- err = app.Server.Serve(ln)
- if err != nil {
- BeeLogger.Critical("ListenAndServe: ", err)
- time.Sleep(100 * time.Microsecond)
- endRunning <- true
- return
- }
- } else {
- err := app.Server.ListenAndServe()
- if err != nil {
- BeeLogger.Critical("ListenAndServe: ", err)
- time.Sleep(100 * time.Microsecond)
- endRunning <- true
- }
+ server := grace.NewServer(addr, app.Handlers)
+ server.Server.ReadTimeout = app.Server.ReadTimeout
+ server.Server.WriteTimeout = app.Server.WriteTimeout
+ if BConfig.Listen.ListenTCP4 {
+ server.Network = "tcp4"
+ }
+ if err := server.ListenAndServe(); err != nil {
+ BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
}
}()
}
+ <-endRunning
+ return
}
+ // run normal mode
+ app.Server.Addr = addr
+ if BConfig.Listen.EnableHTTPS {
+ go func() {
+ time.Sleep(20 * time.Microsecond)
+ if BConfig.Listen.HTTPSPort != 0 {
+ app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
+ }
+ BeeLogger.Info("https server Running on %s", app.Server.Addr)
+ if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
+ BeeLogger.Critical("ListenAndServeTLS: ", err)
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
+ }
+ }()
+ }
+ if BConfig.Listen.EnableHTTP {
+ go func() {
+ app.Server.Addr = addr
+ BeeLogger.Info("http server Running on %s", app.Server.Addr)
+ if BConfig.Listen.ListenTCP4 {
+ ln, err := net.Listen("tcp4", app.Server.Addr)
+ if err != nil {
+ BeeLogger.Critical("ListenAndServe: ", err)
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
+ return
+ }
+ if err = app.Server.Serve(ln); err != nil {
+ BeeLogger.Critical("ListenAndServe: ", err)
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
+ return
+ }
+ } else {
+ if err := app.Server.ListenAndServe(); err != nil {
+ BeeLogger.Critical("ListenAndServe: ", err)
+ time.Sleep(100 * time.Microsecond)
+ endRunning <- true
+ }
+ }
+ }()
+ }
<-endRunning
}
+
+// Router adds a patterned controller handler to BeeApp.
+// it's an alias method of App.Router.
+// usage:
+// simple router
+// beego.Router("/admin", &admin.UserController{})
+// beego.Router("/admin/index", &admin.ArticleController{})
+//
+// regex router
+//
+// beego.Router("/api/:id([0-9]+)", &controllers.RController{})
+//
+// custom rules
+// beego.Router("/api/list",&RestController{},"*:ListFood")
+// beego.Router("/api/create",&RestController{},"post:CreateFood")
+// beego.Router("/api/update",&RestController{},"put:UpdateFood")
+// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
+func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
+ BeeApp.Handlers.Add(rootpath, c, mappingMethods...)
+ return BeeApp
+}
+
+// Include will generate router file in the router/xxx.go from the controller's comments
+// usage:
+// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
+// type BankAccount struct{
+// beego.Controller
+// }
+//
+// register the function
+// func (b *BankAccount)Mapping(){
+// b.Mapping("ShowAccount" , b.ShowAccount)
+// b.Mapping("ModifyAccount", b.ModifyAccount)
+//}
+//
+// //@router /account/:id [get]
+// func (b *BankAccount) ShowAccount(){
+// //logic
+// }
+//
+//
+// //@router /account/:id [post]
+// func (b *BankAccount) ModifyAccount(){
+// //logic
+// }
+//
+// the comments @router url methodlist
+// url support all the function Router's pattern
+// methodlist [get post head put delete options *]
+func Include(cList ...ControllerInterface) *App {
+ BeeApp.Handlers.Include(cList...)
+ return BeeApp
+}
+
+// RESTRouter adds a restful controller handler to BeeApp.
+// its' controller implements beego.ControllerInterface and
+// defines a param "pattern/:objectId" to visit each resource.
+func RESTRouter(rootpath string, c ControllerInterface) *App {
+ Router(rootpath, c)
+ Router(path.Join(rootpath, ":objectId"), c)
+ return BeeApp
+}
+
+// AutoRouter adds defined controller handler to BeeApp.
+// it's same to App.AutoRouter.
+// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
+// visit the url /main/list to exec List function or /main/page to exec Page function.
+func AutoRouter(c ControllerInterface) *App {
+ BeeApp.Handlers.AddAuto(c)
+ return BeeApp
+}
+
+// AutoPrefix adds controller handler to BeeApp with prefix.
+// it's same to App.AutoRouterWithPrefix.
+// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
+// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
+func AutoPrefix(prefix string, c ControllerInterface) *App {
+ BeeApp.Handlers.AddAutoPrefix(prefix, c)
+ return BeeApp
+}
+
+// Get used to register router for Get method
+// usage:
+// beego.Get("/", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Get(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Get(rootpath, f)
+ return BeeApp
+}
+
+// Post used to register router for Post method
+// usage:
+// beego.Post("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Post(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Post(rootpath, f)
+ return BeeApp
+}
+
+// Delete used to register router for Delete method
+// usage:
+// beego.Delete("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Delete(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Delete(rootpath, f)
+ return BeeApp
+}
+
+// Put used to register router for Put method
+// usage:
+// beego.Put("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Put(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Put(rootpath, f)
+ return BeeApp
+}
+
+// Head used to register router for Head method
+// usage:
+// beego.Head("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Head(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Head(rootpath, f)
+ return BeeApp
+}
+
+// Options used to register router for Options method
+// usage:
+// beego.Options("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Options(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Options(rootpath, f)
+ return BeeApp
+}
+
+// Patch used to register router for Patch method
+// usage:
+// beego.Patch("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Patch(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Patch(rootpath, f)
+ return BeeApp
+}
+
+// Any used to register router for all methods
+// usage:
+// beego.Any("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Any(rootpath string, f FilterFunc) *App {
+ BeeApp.Handlers.Any(rootpath, f)
+ return BeeApp
+}
+
+// Handler used to register a Handler router
+// usage:
+// beego.Handler("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
+ BeeApp.Handlers.Handler(rootpath, h, options...)
+ return BeeApp
+}
+
+// InsertFilter adds a FilterFunc with pattern condition and action constant.
+// The pos means action constant including
+// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
+// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
+func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App {
+ BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
+ return BeeApp
+}
diff --git a/beego.go b/beego.go
index 9c34c9f9..04f02071 100644
--- a/beego.go
+++ b/beego.go
@@ -12,243 +12,35 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// beego is an open-source, high-performance, modularity, full-stack web framework
-//
-// package main
-//
-// import "github.com/astaxie/beego"
-//
-// func main() {
-// beego.Run()
-// }
-//
-// more infomation: http://beego.me
package beego
import (
- "net/http"
+ "fmt"
"os"
- "path"
"path/filepath"
"strconv"
"strings"
-
- "github.com/astaxie/beego/session"
)
-// beego web framework version.
-const VERSION = "1.4.3"
+const (
+ // VERSION represent beego web framework version.
+ VERSION = "1.6.0"
-type hookfunc func() error //hook function to run
-var hooks []hookfunc //hook function slice to store the hookfunc
+ // DEV is for develop
+ DEV = "dev"
+ // PROD is for production
+ PROD = "prod"
+)
-// Router adds a patterned controller handler to BeeApp.
-// it's an alias method of App.Router.
-// usage:
-// simple router
-// beego.Router("/admin", &admin.UserController{})
-// beego.Router("/admin/index", &admin.ArticleController{})
-//
-// regex router
-//
-// beego.Router("/api/:id([0-9]+)", &controllers.RController{})
-//
-// custom rules
-// beego.Router("/api/list",&RestController{},"*:ListFood")
-// beego.Router("/api/create",&RestController{},"post:CreateFood")
-// beego.Router("/api/update",&RestController{},"put:UpdateFood")
-// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
-func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
- BeeApp.Handlers.Add(rootpath, c, mappingMethods...)
- return BeeApp
-}
+//hook function to run
+type hookfunc func() error
-// Router add list from
-// usage:
-// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
-// type BankAccount struct{
-// beego.Controller
-// }
-//
-// register the function
-// func (b *BankAccount)Mapping(){
-// b.Mapping("ShowAccount" , b.ShowAccount)
-// b.Mapping("ModifyAccount", b.ModifyAccount)
-//}
-//
-// //@router /account/:id [get]
-// func (b *BankAccount) ShowAccount(){
-// //logic
-// }
-//
-//
-// //@router /account/:id [post]
-// func (b *BankAccount) ModifyAccount(){
-// //logic
-// }
-//
-// the comments @router url methodlist
-// url support all the function Router's pattern
-// methodlist [get post head put delete options *]
-func Include(cList ...ControllerInterface) *App {
- BeeApp.Handlers.Include(cList...)
- return BeeApp
-}
+var (
+ hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc
+)
-// RESTRouter adds a restful controller handler to BeeApp.
-// its' controller implements beego.ControllerInterface and
-// defines a param "pattern/:objectId" to visit each resource.
-func RESTRouter(rootpath string, c ControllerInterface) *App {
- Router(rootpath, c)
- Router(path.Join(rootpath, ":objectId"), c)
- return BeeApp
-}
-
-// AutoRouter adds defined controller handler to BeeApp.
-// it's same to App.AutoRouter.
-// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
-// visit the url /main/list to exec List function or /main/page to exec Page function.
-func AutoRouter(c ControllerInterface) *App {
- BeeApp.Handlers.AddAuto(c)
- return BeeApp
-}
-
-// AutoPrefix adds controller handler to BeeApp with prefix.
-// it's same to App.AutoRouterWithPrefix.
-// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
-// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
-func AutoPrefix(prefix string, c ControllerInterface) *App {
- BeeApp.Handlers.AddAutoPrefix(prefix, c)
- return BeeApp
-}
-
-// register router for Get method
-// usage:
-// beego.Get("/", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Get(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Get(rootpath, f)
- return BeeApp
-}
-
-// register router for Post method
-// usage:
-// beego.Post("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Post(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Post(rootpath, f)
- return BeeApp
-}
-
-// register router for Delete method
-// usage:
-// beego.Delete("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Delete(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Delete(rootpath, f)
- return BeeApp
-}
-
-// register router for Put method
-// usage:
-// beego.Put("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Put(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Put(rootpath, f)
- return BeeApp
-}
-
-// register router for Head method
-// usage:
-// beego.Head("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Head(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Head(rootpath, f)
- return BeeApp
-}
-
-// register router for Options method
-// usage:
-// beego.Options("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Options(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Options(rootpath, f)
- return BeeApp
-}
-
-// register router for Patch method
-// usage:
-// beego.Patch("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Patch(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Patch(rootpath, f)
- return BeeApp
-}
-
-// register router for all method
-// usage:
-// beego.Any("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Any(rootpath string, f FilterFunc) *App {
- BeeApp.Handlers.Any(rootpath, f)
- return BeeApp
-}
-
-// register router for own Handler
-// usage:
-// beego.Handler("/api", func(ctx *context.Context){
-// ctx.Output.Body("hello world")
-// })
-func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
- BeeApp.Handlers.Handler(rootpath, h, options...)
- return BeeApp
-}
-
-// SetViewsPath sets view directory path in beego application.
-func SetViewsPath(path string) *App {
- ViewsPath = path
- return BeeApp
-}
-
-// SetStaticPath sets static directory path and proper url pattern in beego application.
-// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public".
-func SetStaticPath(url string, path string) *App {
- if !strings.HasPrefix(url, "/") {
- url = "/" + url
- }
- url = strings.TrimRight(url, "/")
- StaticDir[url] = path
- return BeeApp
-}
-
-// DelStaticPath removes the static folder setting in this url pattern in beego application.
-func DelStaticPath(url string) *App {
- if !strings.HasPrefix(url, "/") {
- url = "/" + url
- }
- url = strings.TrimRight(url, "/")
- delete(StaticDir, url)
- return BeeApp
-}
-
-// InsertFilter adds a FilterFunc with pattern condition and action constant.
-// The pos means action constant including
-// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
-// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
-func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App {
- BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
- return BeeApp
-}
-
-// The hookfunc will run in beego.Run()
+// AddAPPStartHook is used to register the hookfunc
+// The hookfuncs will run in beego.Run()
// such as sessionInit, middlerware start, buildtemplate, admin start
func AddAPPStartHook(hf hookfunc) {
hooks = append(hooks, hf)
@@ -256,97 +48,60 @@ func AddAPPStartHook(hf hookfunc) {
// Run beego application.
// beego.Run() default run on HttpPort
+// beego.Run("localhost")
// beego.Run(":8089")
// beego.Run("127.0.0.1:8089")
func Run(params ...string) {
+ initBeforeHTTPRun()
+
if len(params) > 0 && params[0] != "" {
strs := strings.Split(params[0], ":")
if len(strs) > 0 && strs[0] != "" {
- HttpAddr = strs[0]
+ BConfig.Listen.HTTPAddr = strs[0]
}
if len(strs) > 1 && strs[1] != "" {
- HttpPort, _ = strconv.Atoi(strs[1])
+ BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1])
}
}
- initBeforeHttpRun()
-
- if EnableAdmin {
- go beeAdminApp.Run()
- }
BeeApp.Run()
}
-func initBeforeHttpRun() {
- // if AppConfigPath not In the conf/app.conf reParse config
- if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") {
- err := ParseConfig()
- if err != nil && AppConfigPath != filepath.Join(workPath, "conf", "app.conf") {
- // configuration is critical to app, panic here if parse failed
- panic(err)
- }
- }
-
- //init mime
- AddAPPStartHook(initMime)
-
- // do hooks function
- for _, hk := range hooks {
- err := hk()
- if err != nil {
- panic(err)
- }
- }
-
- if SessionOn {
- var err error
- sessionConfig := AppConfig.String("sessionConfig")
- if sessionConfig == "" {
- sessionConfig = `{"cookieName":"` + SessionName + `",` +
- `"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
- `"providerConfig":"` + filepath.ToSlash(SessionSavePath) + `",` +
- `"secure":` + strconv.FormatBool(EnableHttpTLS) + `,` +
- `"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
- `"domain":"` + SessionDomain + `",` +
- `"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
- }
- GlobalSessions, err = session.NewManager(SessionProvider,
- sessionConfig)
- if err != nil {
- panic(err)
- }
- go GlobalSessions.GC()
- }
-
- err := BuildTemplate(ViewsPath)
- if err != nil {
- if RunMode == "dev" {
- Warn(err)
- }
- }
-
- registerDefaultErrorHandler()
-
- if EnableDocs {
- Get("/docs", serverDocs)
- Get("/docs/*", serverDocs)
- }
-}
-
-// this function is for test package init
-func TestBeegoInit(apppath string) {
- AppPath = apppath
- RunMode = "test"
- AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
+func initBeforeHTTPRun() {
+ // if AppConfigPath is setted or conf/app.conf exist
err := ParseConfig()
- if err != nil && !os.IsNotExist(err) {
- // for init if doesn't have app.conf will not panic
- Info(err)
+ if err != nil {
+ panic(err)
+ }
+ //init log
+ for adaptor, config := range BConfig.Log.Outputs {
+ err = BeeLogger.SetLogger(adaptor, config)
+ if err != nil {
+ fmt.Printf("%s with the config `%s` got err:%s\n", adaptor, config, err)
+ }
+ }
+
+ SetLogFuncCall(BConfig.Log.FileLineNum)
+
+ //init hooks
+ AddAPPStartHook(registerMime)
+ AddAPPStartHook(registerDefaultErrorHandler)
+ AddAPPStartHook(registerSession)
+ AddAPPStartHook(registerDocs)
+ AddAPPStartHook(registerTemplate)
+ AddAPPStartHook(registerAdmin)
+
+ for _, hk := range hooks {
+ if err := hk(); err != nil {
+ panic(err)
+ }
}
- os.Chdir(AppPath)
- initBeforeHttpRun()
}
-func init() {
- hooks = make([]hookfunc, 0)
+// TestBeegoInit is for test package init
+func TestBeegoInit(ap string) {
+ os.Setenv("BEEGO_RUNMODE", "test")
+ AppConfigPath = filepath.Join(ap, "conf", "app.conf")
+ os.Chdir(ap)
+ initBeforeHTTPRun()
}
diff --git a/cache/README.md b/cache/README.md
index 72d0d1c5..957790e7 100644
--- a/cache/README.md
+++ b/cache/README.md
@@ -26,7 +26,7 @@ Then init a Cache (example with memory adapter)
Use it like this:
- bm.Put("astaxie", 1, 10)
+ bm.Put("astaxie", 1, 10 * time.Second)
bm.Get("astaxie")
bm.IsExist("astaxie")
bm.Delete("astaxie")
@@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
## Memcache adapter
-Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
+Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client.
Configure like this:
@@ -52,7 +52,7 @@ Configure like this:
## Redis adapter
-Redis adapter use the [redigo](http://github.com/garyburd/redigo/redis) client.
+Redis adapter use the [redigo](http://github.com/garyburd/redigo) client.
Configure like this:
diff --git a/cache/cache.go b/cache/cache.go
index ddb2f857..f7158741 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package cache provide a Cache interface and some implemetn engine
// Usage:
//
// import(
@@ -22,7 +23,7 @@
//
// Use it like this:
//
-// bm.Put("astaxie", 1, 10)
+// bm.Put("astaxie", 1, 10 * time.Second)
// bm.Get("astaxie")
// bm.IsExist("astaxie")
// bm.Delete("astaxie")
@@ -32,13 +33,14 @@ package cache
import (
"fmt"
+ "time"
)
// Cache interface contains all behaviors for cache adapter.
// usage:
-// cache.Register("file",cache.NewFileCache()) // this operation is run in init method of file.go.
+// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go.
// c,err := cache.NewCache("file","{....}")
-// c.Put("key",value,3600)
+// c.Put("key",value, 3600 * time.Second)
// v := c.Get("key")
//
// c.Incr("counter") // now is 1
@@ -47,8 +49,10 @@ import (
type Cache interface {
// get cached value by key.
Get(key string) interface{}
+ // GetMulti is a batch version of Get.
+ GetMulti(keys []string) []interface{}
// set cached value with key and expire time.
- Put(key string, val interface{}, timeout int64) error
+ Put(key string, val interface{}, timeout time.Duration) error
// delete cached value by key.
Delete(key string) error
// increase cached int value by key, as a counter.
@@ -63,12 +67,15 @@ type Cache interface {
StartAndGC(config string) error
}
-var adapters = make(map[string]Cache)
+// Instance is a function create a new Cache Instance
+type Instance func() Cache
+
+var adapters = make(map[string]Instance)
// Register makes a cache adapter available by the adapter name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
-func Register(name string, adapter Cache) {
+func Register(name string, adapter Instance) {
if adapter == nil {
panic("cache: Register adapter is nil")
}
@@ -78,15 +85,16 @@ func Register(name string, adapter Cache) {
adapters[name] = adapter
}
-// Create a new cache driver by adapter name and config string.
+// NewCache Create a new cache driver by adapter name and config string.
// config need to be correct JSON as string: {"interval":360}.
// it will start gc automatically.
func NewCache(adapterName, config string) (adapter Cache, err error) {
- adapter, ok := adapters[adapterName]
+ instanceFunc, ok := adapters[adapterName]
if !ok {
err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName)
return
}
+ adapter = instanceFunc()
err = adapter.StartAndGC(config)
if err != nil {
adapter = nil
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 7c43e539..9ceb606a 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -25,7 +25,8 @@ func TestCache(t *testing.T) {
if err != nil {
t.Error("init err")
}
- if err = bm.Put("astaxie", 1, 10); err != nil {
+ timeoutDuration := 10 * time.Second
+ if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@@ -42,7 +43,7 @@ func TestCache(t *testing.T) {
t.Error("check err")
}
- if err = bm.Put("astaxie", 1, 10); err != nil {
+ if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
@@ -65,6 +66,35 @@ func TestCache(t *testing.T) {
if bm.IsExist("astaxie") {
t.Error("delete err")
}
+
+ //test GetMulti
+ if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
+ t.Error("set Error", err)
+ }
+ if !bm.IsExist("astaxie") {
+ t.Error("check err")
+ }
+ if v := bm.Get("astaxie"); v.(string) != "author" {
+ t.Error("get err")
+ }
+
+ if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
+ t.Error("set Error", err)
+ }
+ if !bm.IsExist("astaxie1") {
+ t.Error("check err")
+ }
+
+ vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
+ if len(vv) != 2 {
+ t.Error("GetMulti ERROR")
+ }
+ if vv[0].(string) != "author" {
+ t.Error("GetMulti ERROR")
+ }
+ if vv[1].(string) != "author1" {
+ t.Error("GetMulti ERROR")
+ }
}
func TestFileCache(t *testing.T) {
@@ -72,7 +102,8 @@ func TestFileCache(t *testing.T) {
if err != nil {
t.Error("init err")
}
- if err = bm.Put("astaxie", 1, 10); err != nil {
+ timeoutDuration := 10 * time.Second
+ if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@@ -102,16 +133,36 @@ func TestFileCache(t *testing.T) {
if bm.IsExist("astaxie") {
t.Error("delete err")
}
+
//test string
- if err = bm.Put("astaxie", "author", 10); err != nil {
+ if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
-
if v := bm.Get("astaxie"); v.(string) != "author" {
t.Error("get err")
}
+
+ //test GetMulti
+ if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
+ t.Error("set Error", err)
+ }
+ if !bm.IsExist("astaxie1") {
+ t.Error("check err")
+ }
+
+ vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
+ if len(vv) != 2 {
+ t.Error("GetMulti ERROR")
+ }
+ if vv[0].(string) != "author" {
+ t.Error("GetMulti ERROR")
+ }
+ if vv[1].(string) != "author1" {
+ t.Error("GetMulti ERROR")
+ }
+
os.RemoveAll("cache")
}
diff --git a/cache/conv.go b/cache/conv.go
index 724abfd2..dbdff1c7 100644
--- a/cache/conv.go
+++ b/cache/conv.go
@@ -19,7 +19,7 @@ import (
"strconv"
)
-// convert interface to string.
+// GetString convert interface to string.
func GetString(v interface{}) string {
switch result := v.(type) {
case string:
@@ -34,7 +34,7 @@ func GetString(v interface{}) string {
return ""
}
-// convert interface to int.
+// GetInt convert interface to int.
func GetInt(v interface{}) int {
switch result := v.(type) {
case int:
@@ -52,7 +52,7 @@ func GetInt(v interface{}) int {
return 0
}
-// convert interface to int64.
+// GetInt64 convert interface to int64.
func GetInt64(v interface{}) int64 {
switch result := v.(type) {
case int:
@@ -71,7 +71,7 @@ func GetInt64(v interface{}) int64 {
return 0
}
-// convert interface to float64.
+// GetFloat64 convert interface to float64.
func GetFloat64(v interface{}) float64 {
switch result := v.(type) {
case float64:
@@ -85,7 +85,7 @@ func GetFloat64(v interface{}) float64 {
return 0
}
-// convert interface to bool.
+// GetBool convert interface to bool.
func GetBool(v interface{}) bool {
switch result := v.(type) {
case bool:
@@ -98,15 +98,3 @@ func GetBool(v interface{}) bool {
}
return false
}
-
-// convert interface to byte slice.
-func getByteArray(v interface{}) []byte {
- switch result := v.(type) {
- case []byte:
- return result
- case string:
- return []byte(result)
- default:
- return nil
- }
-}
diff --git a/cache/conv_test.go b/cache/conv_test.go
index 267bb0c9..cf792fa6 100644
--- a/cache/conv_test.go
+++ b/cache/conv_test.go
@@ -27,7 +27,7 @@ func TestGetString(t *testing.T) {
if "test2" != GetString(t2) {
t.Error("get string from byte array error")
}
- var t3 int = 1
+ var t3 = 1
if "1" != GetString(t3) {
t.Error("get string from int error")
}
@@ -35,7 +35,7 @@ func TestGetString(t *testing.T) {
if "1" != GetString(t4) {
t.Error("get string from int64 error")
}
- var t5 float64 = 1.1
+ var t5 = 1.1
if "1.1" != GetString(t5) {
t.Error("get string from float64 error")
}
@@ -46,7 +46,7 @@ func TestGetString(t *testing.T) {
}
func TestGetInt(t *testing.T) {
- var t1 int = 1
+ var t1 = 1
if 1 != GetInt(t1) {
t.Error("get int from int error")
}
@@ -69,7 +69,7 @@ func TestGetInt(t *testing.T) {
func TestGetInt64(t *testing.T) {
var i int64 = 1
- var t1 int = 1
+ var t1 = 1
if i != GetInt64(t1) {
t.Error("get int64 from int error")
}
@@ -91,12 +91,12 @@ func TestGetInt64(t *testing.T) {
}
func TestGetFloat64(t *testing.T) {
- var f float64 = 1.11
+ var f = 1.11
var t1 float32 = 1.11
if f != GetFloat64(t1) {
t.Error("get float64 from float32 error")
}
- var t2 float64 = 1.11
+ var t2 = 1.11
if f != GetFloat64(t2) {
t.Error("get float64 from float64 error")
}
@@ -106,7 +106,7 @@ func TestGetFloat64(t *testing.T) {
}
var f2 float64 = 1
- var t4 int = 1
+ var t4 = 1
if f2 != GetFloat64(t4) {
t.Error("get float64 from int error")
}
@@ -130,21 +130,6 @@ func TestGetBool(t *testing.T) {
}
}
-func TestGetByteArray(t *testing.T) {
- var b = []byte("test")
- var t1 = []byte("test")
- if !byteArrayEquals(b, getByteArray(t1)) {
- t.Error("get byte array from byte array error")
- }
- var t2 = "test"
- if !byteArrayEquals(b, getByteArray(t2)) {
- t.Error("get byte array from string error")
- }
- if nil != getByteArray(nil) {
- t.Error("get byte array from nil error")
- }
-}
-
func byteArrayEquals(a []byte, b []byte) bool {
if len(a) != len(b) {
return false
diff --git a/cache/file.go b/cache/file.go
index bbbbbad2..4b030980 100644
--- a/cache/file.go
+++ b/cache/file.go
@@ -29,23 +29,20 @@ import (
"time"
)
-func init() {
- Register("file", NewFileCache())
-}
-
// FileCacheItem is basic unit of file cache adapter.
// it contains data and expire time.
type FileCacheItem struct {
Data interface{}
- Lastaccess int64
- Expired int64
+ Lastaccess time.Time
+ Expired time.Time
}
+// FileCache Config
var (
- FileCachePath string = "cache" // cache directory
- FileCacheFileSuffix string = ".bin" // cache file suffix
- FileCacheDirectoryLevel int = 2 // cache file deep level if auto generated cache files.
- FileCacheEmbedExpiry int64 = 0 // cache expire time, default is no expire forever.
+ FileCachePath = "cache" // cache directory
+ FileCacheFileSuffix = ".bin" // cache file suffix
+ FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files.
+ FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever.
)
// FileCache is cache adapter for file storage.
@@ -56,14 +53,14 @@ type FileCache struct {
EmbedExpiry int
}
-// Create new file cache with no config.
+// NewFileCache Create new file cache with no config.
// the level and expiry need set in method StartAndGC as config string.
-func NewFileCache() *FileCache {
+func NewFileCache() Cache {
// return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix}
return &FileCache{}
}
-// Start and begin gc for file cache.
+// StartAndGC will start and begin gc for file cache.
// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}
func (fc *FileCache) StartAndGC(config string) error {
@@ -79,7 +76,7 @@ func (fc *FileCache) StartAndGC(config string) error {
cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel)
}
if _, ok := cfg["EmbedExpiry"]; !ok {
- cfg["EmbedExpiry"] = strconv.FormatInt(FileCacheEmbedExpiry, 10)
+ cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10)
}
fc.CachePath = cfg["CachePath"]
fc.FileSuffix = cfg["FileSuffix"]
@@ -120,36 +117,46 @@ func (fc *FileCache) getCacheFileName(key string) string {
// Get value from file cache.
// if non-exist or expired, return empty string.
func (fc *FileCache) Get(key string) interface{} {
- fileData, err := File_get_contents(fc.getCacheFileName(key))
+ fileData, err := FileGetContents(fc.getCacheFileName(key))
if err != nil {
return ""
}
var to FileCacheItem
- Gob_decode(fileData, &to)
- if to.Expired < time.Now().Unix() {
+ GobDecode(fileData, &to)
+ if to.Expired.Before(time.Now()) {
return ""
}
return to.Data
}
+// GetMulti gets values from file cache.
+// if non-exist or expired, return empty string.
+func (fc *FileCache) GetMulti(keys []string) []interface{} {
+ var rc []interface{}
+ for _, key := range keys {
+ rc = append(rc, fc.Get(key))
+ }
+ return rc
+}
+
// Put value into file cache.
// timeout means how long to keep this file, unit of ms.
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
-func (fc *FileCache) Put(key string, val interface{}, timeout int64) error {
+func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error {
gob.Register(val)
item := FileCacheItem{Data: val}
if timeout == FileCacheEmbedExpiry {
- item.Expired = time.Now().Unix() + (86400 * 365 * 10) // ten years
+ item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years
} else {
- item.Expired = time.Now().Unix() + timeout
+ item.Expired = time.Now().Add(timeout)
}
- item.Lastaccess = time.Now().Unix()
- data, err := Gob_encode(item)
+ item.Lastaccess = time.Now()
+ data, err := GobEncode(item)
if err != nil {
return err
}
- return File_put_contents(fc.getCacheFileName(key), data)
+ return FilePutContents(fc.getCacheFileName(key), data)
}
// Delete file cache value.
@@ -161,7 +168,7 @@ func (fc *FileCache) Delete(key string) error {
return nil
}
-// Increase cached int value.
+// Incr will increase cached int value.
// fc value is saving forever unless Delete.
func (fc *FileCache) Incr(key string) error {
data := fc.Get(key)
@@ -175,7 +182,7 @@ func (fc *FileCache) Incr(key string) error {
return nil
}
-// Decrease cached int value.
+// Decr will decrease cached int value.
func (fc *FileCache) Decr(key string) error {
data := fc.Get(key)
var decr int
@@ -188,13 +195,13 @@ func (fc *FileCache) Decr(key string) error {
return nil
}
-// Check value is exist.
+// IsExist check value is exist.
func (fc *FileCache) IsExist(key string) bool {
ret, _ := exists(fc.getCacheFileName(key))
return ret
}
-// Clean cached files.
+// ClearAll will clean cached files.
// not implemented.
func (fc *FileCache) ClearAll() error {
return nil
@@ -212,9 +219,9 @@ func exists(path string) (bool, error) {
return false, err
}
-// Get bytes to file.
+// FileGetContents Get bytes to file.
// if non-exist, create this file.
-func File_get_contents(filename string) (data []byte, e error) {
+func FileGetContents(filename string) (data []byte, e error) {
f, e := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm)
if e != nil {
return
@@ -232,9 +239,9 @@ func File_get_contents(filename string) (data []byte, e error) {
return
}
-// Put bytes to file.
+// FilePutContents Put bytes to file.
// if non-exist, create this file.
-func File_put_contents(filename string, content []byte) error {
+func FilePutContents(filename string, content []byte) error {
fp, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm)
if err != nil {
return err
@@ -244,8 +251,8 @@ func File_put_contents(filename string, content []byte) error {
return err
}
-// Gob encodes file cache item.
-func Gob_encode(data interface{}) ([]byte, error) {
+// GobEncode Gob encodes file cache item.
+func GobEncode(data interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(data)
@@ -255,9 +262,13 @@ func Gob_encode(data interface{}) ([]byte, error) {
return buf.Bytes(), err
}
-// Gob decodes file cache item.
-func Gob_decode(data []byte, to *FileCacheItem) error {
+// GobDecode Gob decodes file cache item.
+func GobDecode(data []byte, to *FileCacheItem) error {
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
return dec.Decode(&to)
}
+
+func init() {
+ Register("file", NewFileCache)
+}
diff --git a/cache/memcache/memcache.go b/cache/memcache/memcache.go
index f5a5c6ef..3f0fe411 100644
--- a/cache/memcache/memcache.go
+++ b/cache/memcache/memcache.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package memcahe for cache provider
+// Package memcache for cache provider
//
// depend on github.com/bradfitz/gomemcache/memcache
//
@@ -36,22 +36,24 @@ import (
"github.com/bradfitz/gomemcache/memcache"
+ "time"
+
"github.com/astaxie/beego/cache"
)
-// Memcache adapter.
-type MemcacheCache struct {
+// Cache Memcache adapter.
+type Cache struct {
conn *memcache.Client
conninfo []string
}
-// create new memcache adapter.
-func NewMemCache() *MemcacheCache {
- return &MemcacheCache{}
+// NewMemCache create new memcache adapter.
+func NewMemCache() cache.Cache {
+ return &Cache{}
}
-// get value from memcache.
-func (rc *MemcacheCache) Get(key string) interface{} {
+// Get get value from memcache.
+func (rc *Cache) Get(key string) interface{} {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -63,8 +65,33 @@ func (rc *MemcacheCache) Get(key string) interface{} {
return nil
}
-// put value to memcache. only support string.
-func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
+// GetMulti get value from memcache.
+func (rc *Cache) GetMulti(keys []string) []interface{} {
+ size := len(keys)
+ var rv []interface{}
+ if rc.conn == nil {
+ if err := rc.connectInit(); err != nil {
+ for i := 0; i < size; i++ {
+ rv = append(rv, err)
+ }
+ return rv
+ }
+ }
+ mv, err := rc.conn.GetMulti(keys)
+ if err == nil {
+ for _, v := range mv {
+ rv = append(rv, string(v.Value))
+ }
+ return rv
+ }
+ for i := 0; i < size; i++ {
+ rv = append(rv, err)
+ }
+ return rv
+}
+
+// Put put value to memcache. only support string.
+func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -74,12 +101,12 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
if !ok {
return errors.New("val must string")
}
- item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout)}
+ item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout / time.Second)}
return rc.conn.Set(&item)
}
-// delete value in memcache.
-func (rc *MemcacheCache) Delete(key string) error {
+// Delete delete value in memcache.
+func (rc *Cache) Delete(key string) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -88,8 +115,8 @@ func (rc *MemcacheCache) Delete(key string) error {
return rc.conn.Delete(key)
}
-// increase counter.
-func (rc *MemcacheCache) Incr(key string) error {
+// Incr increase counter.
+func (rc *Cache) Incr(key string) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -99,8 +126,8 @@ func (rc *MemcacheCache) Incr(key string) error {
return err
}
-// decrease counter.
-func (rc *MemcacheCache) Decr(key string) error {
+// Decr decrease counter.
+func (rc *Cache) Decr(key string) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -110,8 +137,8 @@ func (rc *MemcacheCache) Decr(key string) error {
return err
}
-// check value exists in memcache.
-func (rc *MemcacheCache) IsExist(key string) bool {
+// IsExist check value exists in memcache.
+func (rc *Cache) IsExist(key string) bool {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return false
@@ -124,8 +151,8 @@ func (rc *MemcacheCache) IsExist(key string) bool {
return true
}
-// clear all cached in memcache.
-func (rc *MemcacheCache) ClearAll() error {
+// ClearAll clear all cached in memcache.
+func (rc *Cache) ClearAll() error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@@ -134,10 +161,10 @@ func (rc *MemcacheCache) ClearAll() error {
return rc.conn.FlushAll()
}
-// start memcache adapter.
+// StartAndGC start memcache adapter.
// config string is like {"conn":"connection info"}.
// if connecting error, return.
-func (rc *MemcacheCache) StartAndGC(config string) error {
+func (rc *Cache) StartAndGC(config string) error {
var cf map[string]string
json.Unmarshal([]byte(config), &cf)
if _, ok := cf["conn"]; !ok {
@@ -153,11 +180,11 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
}
// connect to memcache and keep the connection.
-func (rc *MemcacheCache) connectInit() error {
+func (rc *Cache) connectInit() error {
rc.conn = memcache.New(rc.conninfo...)
return nil
}
func init() {
- cache.Register("memcache", NewMemCache())
+ cache.Register("memcache", NewMemCache)
}
diff --git a/cache/memcache/memcache_test.go b/cache/memcache/memcache_test.go
new file mode 100644
index 00000000..8d98c177
--- /dev/null
+++ b/cache/memcache/memcache_test.go
@@ -0,0 +1,108 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memcache
+
+import (
+ _ "github.com/bradfitz/gomemcache/memcache"
+
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/astaxie/beego/cache"
+)
+
+func TestMemcacheCache(t *testing.T) {
+ bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`)
+ if err != nil {
+ t.Error("init err")
+ }
+ timeoutDuration := 10 * time.Second
+ if err = bm.Put("astaxie", "1", timeoutDuration); err != nil {
+ t.Error("set Error", err)
+ }
+ if !bm.IsExist("astaxie") {
+ t.Error("check err")
+ }
+
+ time.Sleep(10 * time.Second)
+
+ if bm.IsExist("astaxie") {
+ t.Error("check err")
+ }
+ if err = bm.Put("astaxie", "1", timeoutDuration); err != nil {
+ t.Error("set Error", err)
+ }
+
+ if v, err := strconv.Atoi(bm.Get("astaxie").(string)); err != nil || v != 1 {
+ t.Error("get err")
+ }
+
+ if err = bm.Incr("astaxie"); err != nil {
+ t.Error("Incr Error", err)
+ }
+
+ if v, err := strconv.Atoi(bm.Get("astaxie").(string)); err != nil || v != 2 {
+ t.Error("get err")
+ }
+
+ if err = bm.Decr("astaxie"); err != nil {
+ t.Error("Decr Error", err)
+ }
+
+ if v, err := strconv.Atoi(bm.Get("astaxie").(string)); err != nil || v != 1 {
+ t.Error("get err")
+ }
+ bm.Delete("astaxie")
+ if bm.IsExist("astaxie") {
+ t.Error("delete err")
+ }
+
+ //test string
+ if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
+ t.Error("set Error", err)
+ }
+ if !bm.IsExist("astaxie") {
+ t.Error("check err")
+ }
+
+ if v := bm.Get("astaxie").(string); v != "author" {
+ t.Error("get err")
+ }
+
+ //test GetMulti
+ if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
+ t.Error("set Error", err)
+ }
+ if !bm.IsExist("astaxie1") {
+ t.Error("check err")
+ }
+
+ vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
+ if len(vv) != 2 {
+ t.Error("GetMulti ERROR")
+ }
+ if vv[0].(string) != "author" && vv[0].(string) != "author1" {
+ t.Error("GetMulti ERROR")
+ }
+ if vv[1].(string) != "author1" && vv[1].(string) != "author" {
+ t.Error("GetMulti ERROR")
+ }
+
+ // test clear all
+ if err = bm.ClearAll(); err != nil {
+ t.Error("clear all err")
+ }
+}
diff --git a/cache/memory.go b/cache/memory.go
index b90d227c..fff2ebbb 100644
--- a/cache/memory.go
+++ b/cache/memory.go
@@ -17,34 +17,41 @@ package cache
import (
"encoding/json"
"errors"
- "fmt"
"sync"
"time"
)
var (
- // clock time of recycling the expired cache items in memory.
- DefaultEvery int = 60 // 1 minute
+ // DefaultEvery means the clock time of recycling the expired cache items in memory.
+ DefaultEvery = 60 // 1 minute
)
-// Memory cache item.
+// MemoryItem store memory cache item.
type MemoryItem struct {
- val interface{}
- Lastaccess time.Time
- expired int64
+ val interface{}
+ createdTime time.Time
+ lifespan time.Duration
}
-// Memory cache adapter.
+func (mi *MemoryItem) isExpire() bool {
+ // 0 means forever
+ if mi.lifespan == 0 {
+ return false
+ }
+ return time.Now().Sub(mi.createdTime) > mi.lifespan
+}
+
+// MemoryCache is Memory cache adapter.
// it contains a RW locker for safe map storage.
type MemoryCache struct {
- lock sync.RWMutex
+ sync.RWMutex
dur time.Duration
items map[string]*MemoryItem
Every int // run an expiration check Every clock time
}
// NewMemoryCache returns a new MemoryCache.
-func NewMemoryCache() *MemoryCache {
+func NewMemoryCache() Cache {
cache := MemoryCache{items: make(map[string]*MemoryItem)}
return &cache
}
@@ -52,11 +59,10 @@ func NewMemoryCache() *MemoryCache {
// Get cache from memory.
// if non-existed or expired, return nil.
func (bc *MemoryCache) Get(name string) interface{} {
- bc.lock.RLock()
- defer bc.lock.RUnlock()
+ bc.RLock()
+ defer bc.RUnlock()
if itm, ok := bc.items[name]; ok {
- if (time.Now().Unix() - itm.Lastaccess.Unix()) > itm.expired {
- go bc.Delete(name)
+ if itm.isExpire() {
return nil
}
return itm.val
@@ -64,23 +70,33 @@ func (bc *MemoryCache) Get(name string) interface{} {
return nil
}
+// GetMulti gets caches from memory.
+// if non-existed or expired, return nil.
+func (bc *MemoryCache) GetMulti(names []string) []interface{} {
+ var rc []interface{}
+ for _, name := range names {
+ rc = append(rc, bc.Get(name))
+ }
+ return rc
+}
+
// Put cache to memory.
-// if expired is 0, it will be cleaned by next gc operation ( default gc clock is 1 minute).
-func (bc *MemoryCache) Put(name string, value interface{}, expired int64) error {
- bc.lock.Lock()
- defer bc.lock.Unlock()
+// if lifespan is 0, it will be forever till restart.
+func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error {
+ bc.Lock()
+ defer bc.Unlock()
bc.items[name] = &MemoryItem{
- val: value,
- Lastaccess: time.Now(),
- expired: expired,
+ val: value,
+ createdTime: time.Now(),
+ lifespan: lifespan,
}
return nil
}
-/// Delete cache in memory.
+// Delete cache in memory.
func (bc *MemoryCache) Delete(name string) error {
- bc.lock.Lock()
- defer bc.lock.Unlock()
+ bc.Lock()
+ defer bc.Unlock()
if _, ok := bc.items[name]; !ok {
return errors.New("key not exist")
}
@@ -91,11 +107,11 @@ func (bc *MemoryCache) Delete(name string) error {
return nil
}
-// Increase cache counter in memory.
-// it supports int,int64,int32,uint,uint64,uint32.
+// Incr increase cache counter in memory.
+// it supports int,int32,int64,uint,uint32,uint64.
func (bc *MemoryCache) Incr(key string) error {
- bc.lock.RLock()
- defer bc.lock.RUnlock()
+ bc.RLock()
+ defer bc.RUnlock()
itm, ok := bc.items[key]
if !ok {
return errors.New("key not exist")
@@ -103,10 +119,10 @@ func (bc *MemoryCache) Incr(key string) error {
switch itm.val.(type) {
case int:
itm.val = itm.val.(int) + 1
- case int64:
- itm.val = itm.val.(int64) + 1
case int32:
itm.val = itm.val.(int32) + 1
+ case int64:
+ itm.val = itm.val.(int64) + 1
case uint:
itm.val = itm.val.(uint) + 1
case uint32:
@@ -114,15 +130,15 @@ func (bc *MemoryCache) Incr(key string) error {
case uint64:
itm.val = itm.val.(uint64) + 1
default:
- return errors.New("item val is not int int64 int32")
+ return errors.New("item val is not (u)int (u)int32 (u)int64")
}
return nil
}
-// Decrease counter in memory.
+// Decr decrease counter in memory.
func (bc *MemoryCache) Decr(key string) error {
- bc.lock.RLock()
- defer bc.lock.RUnlock()
+ bc.RLock()
+ defer bc.RUnlock()
itm, ok := bc.items[key]
if !ok {
return errors.New("key not exist")
@@ -158,23 +174,25 @@ func (bc *MemoryCache) Decr(key string) error {
return nil
}
-// check cache exist in memory.
+// IsExist check cache exist in memory.
func (bc *MemoryCache) IsExist(name string) bool {
- bc.lock.RLock()
- defer bc.lock.RUnlock()
- _, ok := bc.items[name]
- return ok
+ bc.RLock()
+ defer bc.RUnlock()
+ if v, ok := bc.items[name]; ok {
+ return !v.isExpire()
+ }
+ return false
}
-// delete all cache in memory.
+// ClearAll will delete all cache in memory.
func (bc *MemoryCache) ClearAll() error {
- bc.lock.Lock()
- defer bc.lock.Unlock()
+ bc.Lock()
+ defer bc.Unlock()
bc.items = make(map[string]*MemoryItem)
return nil
}
-// start memory cache. it will check expiration in every clock time.
+// StartAndGC start memory cache. it will check expiration in every clock time.
func (bc *MemoryCache) StartAndGC(config string) error {
var cf map[string]int
json.Unmarshal([]byte(config), &cf)
@@ -182,10 +200,7 @@ func (bc *MemoryCache) StartAndGC(config string) error {
cf = make(map[string]int)
cf["interval"] = DefaultEvery
}
- dur, err := time.ParseDuration(fmt.Sprintf("%ds", cf["interval"]))
- if err != nil {
- return err
- }
+ dur := time.Duration(cf["interval"]) * time.Second
bc.Every = cf["interval"]
bc.dur = dur
go bc.vaccuum()
@@ -202,21 +217,22 @@ func (bc *MemoryCache) vaccuum() {
if bc.items == nil {
return
}
- for name, _ := range bc.items {
- bc.item_expired(name)
+ for name := range bc.items {
+ bc.itemExpired(name)
}
}
}
-// item_expired returns true if an item is expired.
-func (bc *MemoryCache) item_expired(name string) bool {
- bc.lock.Lock()
- defer bc.lock.Unlock()
+// itemExpired returns true if an item is expired.
+func (bc *MemoryCache) itemExpired(name string) bool {
+ bc.Lock()
+ defer bc.Unlock()
+
itm, ok := bc.items[name]
if !ok {
return true
}
- if time.Now().Unix()-itm.Lastaccess.Unix() >= itm.expired {
+ if itm.isExpire() {
delete(bc.items, name)
return true
}
@@ -224,5 +240,5 @@ func (bc *MemoryCache) item_expired(name string) bool {
}
func init() {
- Register("memory", NewMemoryCache())
+ Register("memory", NewMemoryCache)
}
diff --git a/cache/redis/redis.go b/cache/redis/redis.go
index 0e07eaed..781e3836 100644
--- a/cache/redis/redis.go
+++ b/cache/redis/redis.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package redis for cache provider
+// Package redis for cache provider
//
// depend on github.com/garyburd/redigo/redis
//
@@ -41,25 +41,26 @@ import (
)
var (
- // the collection name of redis for cache adapter.
- DefaultKey string = "beecacheRedis"
+ // DefaultKey the collection name of redis for cache adapter.
+ DefaultKey = "beecacheRedis"
)
-// Redis cache adapter.
-type RedisCache struct {
+// Cache is Redis cache adapter.
+type Cache struct {
p *redis.Pool // redis connection pool
conninfo string
dbNum int
key string
+ password string
}
-// create new redis cache with default collection name.
-func NewRedisCache() *RedisCache {
- return &RedisCache{key: DefaultKey}
+// NewRedisCache create new redis cache with default collection name.
+func NewRedisCache() cache.Cache {
+ return &Cache{key: DefaultKey}
}
// actually do the redis cmds
-func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
+func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
c := rc.p.Get()
defer c.Close()
@@ -67,17 +68,50 @@ func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interfa
}
// Get cache from redis.
-func (rc *RedisCache) Get(key string) interface{} {
+func (rc *Cache) Get(key string) interface{} {
if v, err := rc.do("GET", key); err == nil {
return v
}
return nil
}
-// put cache to redis.
-func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
+// GetMulti get cache from redis.
+func (rc *Cache) GetMulti(keys []string) []interface{} {
+ size := len(keys)
+ var rv []interface{}
+ c := rc.p.Get()
+ defer c.Close()
var err error
- if _, err = rc.do("SETEX", key, timeout, val); err != nil {
+ for _, key := range keys {
+ err = c.Send("GET", key)
+ if err != nil {
+ goto ERROR
+ }
+ }
+ if err = c.Flush(); err != nil {
+ goto ERROR
+ }
+ for i := 0; i < size; i++ {
+ if v, err := c.Receive(); err == nil {
+ rv = append(rv, v.([]byte))
+ } else {
+ rv = append(rv, err)
+ }
+ }
+ return rv
+ERROR:
+ rv = rv[0:0]
+ for i := 0; i < size; i++ {
+ rv = append(rv, nil)
+ }
+
+ return rv
+}
+
+// Put put cache to redis.
+func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
+ var err error
+ if _, err = rc.do("SETEX", key, int64(timeout/time.Second), val); err != nil {
return err
}
@@ -87,8 +121,8 @@ func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
return err
}
-// delete cache in redis.
-func (rc *RedisCache) Delete(key string) error {
+// Delete delete cache in redis.
+func (rc *Cache) Delete(key string) error {
var err error
if _, err = rc.do("DEL", key); err != nil {
return err
@@ -97,8 +131,8 @@ func (rc *RedisCache) Delete(key string) error {
return err
}
-// check cache's existence in redis.
-func (rc *RedisCache) IsExist(key string) bool {
+// IsExist check cache's existence in redis.
+func (rc *Cache) IsExist(key string) bool {
v, err := redis.Bool(rc.do("EXISTS", key))
if err != nil {
return false
@@ -111,20 +145,20 @@ func (rc *RedisCache) IsExist(key string) bool {
return v
}
-// increase counter in redis.
-func (rc *RedisCache) Incr(key string) error {
+// Incr increase counter in redis.
+func (rc *Cache) Incr(key string) error {
_, err := redis.Bool(rc.do("INCRBY", key, 1))
return err
}
-// decrease counter in redis.
-func (rc *RedisCache) Decr(key string) error {
+// Decr decrease counter in redis.
+func (rc *Cache) Decr(key string) error {
_, err := redis.Bool(rc.do("INCRBY", key, -1))
return err
}
-// clean all cache in redis. delete this redis collection.
-func (rc *RedisCache) ClearAll() error {
+// ClearAll clean all cache in redis. delete this redis collection.
+func (rc *Cache) ClearAll() error {
cachedKeys, err := redis.Strings(rc.do("HKEYS", rc.key))
if err != nil {
return err
@@ -138,27 +172,31 @@ func (rc *RedisCache) ClearAll() error {
return err
}
-// start redis cache adapter.
+// StartAndGC start redis cache adapter.
// config is like {"key":"collection key","conn":"connection info","dbNum":"0"}
// the cache item in redis are stored forever,
// so no gc operation.
-func (rc *RedisCache) StartAndGC(config string) error {
+func (rc *Cache) StartAndGC(config string) error {
var cf map[string]string
json.Unmarshal([]byte(config), &cf)
if _, ok := cf["key"]; !ok {
cf["key"] = DefaultKey
}
-
if _, ok := cf["conn"]; !ok {
return errors.New("config has no conn key")
}
if _, ok := cf["dbNum"]; !ok {
cf["dbNum"] = "0"
}
+ if _, ok := cf["password"]; !ok {
+ cf["password"] = ""
+ }
rc.key = cf["key"]
rc.conninfo = cf["conn"]
rc.dbNum, _ = strconv.Atoi(cf["dbNum"])
+ rc.password = cf["password"]
+
rc.connectInit()
c := rc.p.Get()
@@ -168,9 +206,20 @@ func (rc *RedisCache) StartAndGC(config string) error {
}
// connect to redis.
-func (rc *RedisCache) connectInit() {
+func (rc *Cache) connectInit() {
dialFunc := func() (c redis.Conn, err error) {
c, err = redis.Dial("tcp", rc.conninfo)
+ if err != nil {
+ return nil, err
+ }
+
+ if rc.password != "" {
+ if _, err := c.Do("AUTH", rc.password); err != nil {
+ c.Close()
+ return nil, err
+ }
+ }
+
_, selecterr := c.Do("SELECT", rc.dbNum)
if selecterr != nil {
c.Close()
@@ -187,5 +236,5 @@ func (rc *RedisCache) connectInit() {
}
func init() {
- cache.Register("redis", NewRedisCache())
+ cache.Register("redis", NewRedisCache)
}
diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go
index fbe82ac5..47c5acc6 100644
--- a/cache/redis/redis_test.go
+++ b/cache/redis/redis_test.go
@@ -28,19 +28,20 @@ func TestRedisCache(t *testing.T) {
if err != nil {
t.Error("init err")
}
- if err = bm.Put("astaxie", 1, 10); err != nil {
+ timeoutDuration := 10 * time.Second
+ if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
- time.Sleep(10 * time.Second)
+ time.Sleep(11 * time.Second)
if bm.IsExist("astaxie") {
t.Error("check err")
}
- if err = bm.Put("astaxie", 1, 10); err != nil {
+ if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
@@ -67,8 +68,9 @@ func TestRedisCache(t *testing.T) {
if bm.IsExist("astaxie") {
t.Error("delete err")
}
+
//test string
- if err = bm.Put("astaxie", "author", 10); err != nil {
+ if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@@ -78,6 +80,26 @@ func TestRedisCache(t *testing.T) {
if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" {
t.Error("get err")
}
+
+ //test GetMulti
+ if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
+ t.Error("set Error", err)
+ }
+ if !bm.IsExist("astaxie1") {
+ t.Error("check err")
+ }
+
+ vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
+ if len(vv) != 2 {
+ t.Error("GetMulti ERROR")
+ }
+ if v, _ := redis.String(vv[0], nil); v != "author" {
+ t.Error("GetMulti ERROR")
+ }
+ if v, _ := redis.String(vv[1], nil); v != "author1" {
+ t.Error("GetMulti ERROR")
+ }
+
// test clear all
if err = bm.ClearAll(); err != nil {
t.Error("clear all err")
diff --git a/config.go b/config.go
index f326ad22..8cf21530 100644
--- a/config.go
+++ b/config.go
@@ -15,77 +15,269 @@
package beego
import (
- "fmt"
"html/template"
"os"
"path/filepath"
- "runtime"
"strings"
"github.com/astaxie/beego/config"
- "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
)
-var (
- BeeApp *App // beego application
- AppName string
- AppPath string
- workPath string
- AppConfigPath string
+// Config is the main struct for BConfig
+type Config struct {
+ AppName string //Application name
+ RunMode string //Running Mode: dev | prod
+ RouterCaseSensitive bool
+ ServerName string
+ RecoverPanic bool
+ CopyRequestBody bool
+ EnableGzip bool
+ MaxMemory int64
+ EnableErrorsShow bool
+ Listen Listen
+ WebConfig WebConfig
+ Log LogConfig
+}
+
+// Listen holds for http and https related config
+type Listen struct {
+ Graceful bool // Graceful means use graceful module to start the server
+ ServerTimeOut int64
+ ListenTCP4 bool
+ EnableHTTP bool
+ HTTPAddr string
+ HTTPPort int
+ EnableHTTPS bool
+ HTTPSAddr string
+ HTTPSPort int
+ HTTPSCertFile string
+ HTTPSKeyFile string
+ EnableAdmin bool
+ AdminAddr string
+ AdminPort int
+ EnableFcgi bool
+ EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O
+}
+
+// WebConfig holds web related config
+type WebConfig struct {
+ AutoRender bool
+ EnableDocs bool
+ FlashName string
+ FlashSeparator string
+ DirectoryIndex bool
StaticDir map[string]string
- TemplateCache map[string]*template.Template // template caching map
- StaticExtensionsToGzip []string // files with should be compressed with gzip (.js,.css,etc)
- EnableHttpListen bool
- HttpAddr string
- HttpPort int
- ListenTCP4 bool
- EnableHttpTLS bool
- HttpsPort int
- HttpCertFile string
- HttpKeyFile string
- RecoverPanic bool // flag of auto recover panic
- AutoRender bool // flag of render template automatically
- ViewsPath string
- AppConfig *beegoAppConfig
- RunMode string // run mode, "dev" or "prod"
- GlobalSessions *session.Manager // global session mananger
- SessionOn bool // flag of starting session auto. default is false.
- SessionProvider string // default session provider, memory, mysql , redis ,etc.
- SessionName string // the cookie name when saving session id into cookie.
- SessionGCMaxLifetime int64 // session gc time for auto cleaning expired session.
- SessionSavePath string // if use mysql/redis/file provider, define save path to connection info.
- SessionCookieLifeTime int // the life time of session id in cookie.
- SessionAutoSetCookie bool // auto setcookie
- SessionDomain string // the cookie domain default is empty
- UseFcgi bool
- UseStdIo bool
- MaxMemory int64
- EnableGzip bool // flag of enable gzip
- DirectoryIndex bool // flag of display directory index. default is false.
- HttpServerTimeOut int64
- ErrorsShow bool // flag of show errors in page. if true, show error and trace info in page rendered with error template.
- XSRFKEY string // xsrf hash salt string.
- EnableXSRF bool // flag of enable xsrf.
- XSRFExpire int // the expiry of xsrf value.
- CopyRequestBody bool // flag of copy raw request body in context.
+ StaticExtensionsToGzip []string
TemplateLeft string
TemplateRight string
- BeegoServerName string // beego server name exported in response header.
- EnableAdmin bool // flag of enable admin module to log every request info.
- AdminHttpAddr string // http server configurations for admin module.
- AdminHttpPort int
- FlashName string // name of the flash variable found in response header and cookie
- FlashSeperator string // used to seperate flash key:value
- AppConfigProvider string // config provider
- EnableDocs bool // enable generate docs & server docs API Swagger
- RouterCaseSensitive bool // router case sensitive default is true
- AccessLogs bool // print access logs, default is false
+ ViewsPath string
+ EnableXSRF bool
+ XSRFKey string
+ XSRFExpire int
+ Session SessionConfig
+}
+
+// SessionConfig holds session related config
+type SessionConfig struct {
+ SessionOn bool
+ SessionProvider string
+ SessionName string
+ SessionGCMaxLifetime int64
+ SessionProviderConfig string
+ SessionCookieLifeTime int
+ SessionAutoSetCookie bool
+ SessionDomain string
+}
+
+// LogConfig holds Log related config
+type LogConfig struct {
+ AccessLogs bool
+ FileLineNum bool
+ Outputs map[string]string // Store Adaptor : config
+}
+
+var (
+ // BConfig is the default config for Application
+ BConfig *Config
+ // AppConfig is the instance of Config, store the config information from file
+ AppConfig *beegoAppConfig
+ // AppConfigPath is the path to the config files
+ AppConfigPath string
+ // AppConfigProvider is the provider for the config, default is ini
+ AppConfigProvider = "ini"
+ // TemplateCache stores template caching
+ TemplateCache map[string]*template.Template
+ // GlobalSessions is the instance for the session manager
+ GlobalSessions *session.Manager
)
+func init() {
+ BConfig = &Config{
+ AppName: "beego",
+ RunMode: DEV,
+ RouterCaseSensitive: true,
+ ServerName: "beegoServer:" + VERSION,
+ RecoverPanic: true,
+ CopyRequestBody: false,
+ EnableGzip: false,
+ MaxMemory: 1 << 26, //64MB
+ EnableErrorsShow: true,
+ Listen: Listen{
+ Graceful: false,
+ ServerTimeOut: 0,
+ ListenTCP4: false,
+ EnableHTTP: true,
+ HTTPAddr: "",
+ HTTPPort: 8080,
+ EnableHTTPS: false,
+ HTTPSAddr: "",
+ HTTPSPort: 10443,
+ HTTPSCertFile: "",
+ HTTPSKeyFile: "",
+ EnableAdmin: false,
+ AdminAddr: "",
+ AdminPort: 8088,
+ EnableFcgi: false,
+ EnableStdIo: false,
+ },
+ WebConfig: WebConfig{
+ AutoRender: true,
+ EnableDocs: false,
+ FlashName: "BEEGO_FLASH",
+ FlashSeparator: "BEEGOFLASH",
+ DirectoryIndex: false,
+ StaticDir: map[string]string{"/static": "static"},
+ StaticExtensionsToGzip: []string{".css", ".js"},
+ TemplateLeft: "{{",
+ TemplateRight: "}}",
+ ViewsPath: "views",
+ EnableXSRF: false,
+ XSRFKey: "beegoxsrf",
+ XSRFExpire: 0,
+ Session: SessionConfig{
+ SessionOn: false,
+ SessionProvider: "memory",
+ SessionName: "beegosessionID",
+ SessionGCMaxLifetime: 3600,
+ SessionProviderConfig: "",
+ SessionCookieLifeTime: 0, //set cookie default is the brower life
+ SessionAutoSetCookie: true,
+ SessionDomain: "",
+ },
+ },
+ Log: LogConfig{
+ AccessLogs: false,
+ FileLineNum: true,
+ Outputs: map[string]string{"console": ""},
+ },
+ }
+ ParseConfig()
+}
+
+// ParseConfig parsed default config file.
+// now only support ini, next will support json.
+func ParseConfig() (err error) {
+ if AppConfigPath == "" {
+ if utils.FileExists(filepath.Join("conf", "app.conf")) {
+ AppConfigPath = filepath.Join("conf", "app.conf")
+ } else {
+ AppConfig = &beegoAppConfig{config.NewFakeConfig()}
+ return
+ }
+ }
+ AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath)
+ if err != nil {
+ return err
+ }
+ // set the runmode first
+ if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
+ BConfig.RunMode = envRunMode
+ } else if runmode := AppConfig.String("RunMode"); runmode != "" {
+ BConfig.RunMode = runmode
+ }
+
+ BConfig.AppName = AppConfig.DefaultString("AppName", BConfig.AppName)
+ BConfig.RecoverPanic = AppConfig.DefaultBool("RecoverPanic", BConfig.RecoverPanic)
+ BConfig.RouterCaseSensitive = AppConfig.DefaultBool("RouterCaseSensitive", BConfig.RouterCaseSensitive)
+ BConfig.ServerName = AppConfig.DefaultString("ServerName", BConfig.ServerName)
+ BConfig.EnableGzip = AppConfig.DefaultBool("EnableGzip", BConfig.EnableGzip)
+ BConfig.EnableErrorsShow = AppConfig.DefaultBool("EnableErrorsShow", BConfig.EnableErrorsShow)
+ BConfig.CopyRequestBody = AppConfig.DefaultBool("CopyRequestBody", BConfig.CopyRequestBody)
+ BConfig.MaxMemory = AppConfig.DefaultInt64("MaxMemory", BConfig.MaxMemory)
+ BConfig.Listen.Graceful = AppConfig.DefaultBool("Graceful", BConfig.Listen.Graceful)
+ BConfig.Listen.HTTPAddr = AppConfig.String("HTTPAddr")
+ BConfig.Listen.HTTPPort = AppConfig.DefaultInt("HTTPPort", BConfig.Listen.HTTPPort)
+ BConfig.Listen.ListenTCP4 = AppConfig.DefaultBool("ListenTCP4", BConfig.Listen.ListenTCP4)
+ BConfig.Listen.EnableHTTP = AppConfig.DefaultBool("EnableHTTP", BConfig.Listen.EnableHTTP)
+ BConfig.Listen.EnableHTTPS = AppConfig.DefaultBool("EnableHTTPS", BConfig.Listen.EnableHTTPS)
+ BConfig.Listen.HTTPSAddr = AppConfig.DefaultString("HTTPSAddr", BConfig.Listen.HTTPSAddr)
+ BConfig.Listen.HTTPSPort = AppConfig.DefaultInt("HTTPSPort", BConfig.Listen.HTTPSPort)
+ BConfig.Listen.HTTPSCertFile = AppConfig.DefaultString("HTTPSCertFile", BConfig.Listen.HTTPSCertFile)
+ BConfig.Listen.HTTPSKeyFile = AppConfig.DefaultString("HTTPSKeyFile", BConfig.Listen.HTTPSKeyFile)
+ BConfig.Listen.EnableAdmin = AppConfig.DefaultBool("EnableAdmin", BConfig.Listen.EnableAdmin)
+ BConfig.Listen.AdminAddr = AppConfig.DefaultString("AdminAddr", BConfig.Listen.AdminAddr)
+ BConfig.Listen.AdminPort = AppConfig.DefaultInt("AdminPort", BConfig.Listen.AdminPort)
+ BConfig.Listen.EnableFcgi = AppConfig.DefaultBool("EnableFcgi", BConfig.Listen.EnableFcgi)
+ BConfig.Listen.EnableStdIo = AppConfig.DefaultBool("EnableStdIo", BConfig.Listen.EnableStdIo)
+ BConfig.Listen.ServerTimeOut = AppConfig.DefaultInt64("ServerTimeOut", BConfig.Listen.ServerTimeOut)
+ BConfig.WebConfig.AutoRender = AppConfig.DefaultBool("AutoRender", BConfig.WebConfig.AutoRender)
+ BConfig.WebConfig.ViewsPath = AppConfig.DefaultString("ViewsPath", BConfig.WebConfig.ViewsPath)
+ BConfig.WebConfig.DirectoryIndex = AppConfig.DefaultBool("DirectoryIndex", BConfig.WebConfig.DirectoryIndex)
+ BConfig.WebConfig.FlashName = AppConfig.DefaultString("FlashName", BConfig.WebConfig.FlashName)
+ BConfig.WebConfig.FlashSeparator = AppConfig.DefaultString("FlashSeparator", BConfig.WebConfig.FlashSeparator)
+ BConfig.WebConfig.EnableDocs = AppConfig.DefaultBool("EnableDocs", BConfig.WebConfig.EnableDocs)
+ BConfig.WebConfig.XSRFKey = AppConfig.DefaultString("XSRFKEY", BConfig.WebConfig.XSRFKey)
+ BConfig.WebConfig.EnableXSRF = AppConfig.DefaultBool("EnableXSRF", BConfig.WebConfig.EnableXSRF)
+ BConfig.WebConfig.XSRFExpire = AppConfig.DefaultInt("XSRFExpire", BConfig.WebConfig.XSRFExpire)
+ BConfig.WebConfig.TemplateLeft = AppConfig.DefaultString("TemplateLeft", BConfig.WebConfig.TemplateLeft)
+ BConfig.WebConfig.TemplateRight = AppConfig.DefaultString("TemplateRight", BConfig.WebConfig.TemplateRight)
+ BConfig.WebConfig.Session.SessionOn = AppConfig.DefaultBool("SessionOn", BConfig.WebConfig.Session.SessionOn)
+ BConfig.WebConfig.Session.SessionProvider = AppConfig.DefaultString("SessionProvider", BConfig.WebConfig.Session.SessionProvider)
+ BConfig.WebConfig.Session.SessionName = AppConfig.DefaultString("SessionName", BConfig.WebConfig.Session.SessionName)
+ BConfig.WebConfig.Session.SessionProviderConfig = AppConfig.DefaultString("SessionProviderConfig", BConfig.WebConfig.Session.SessionProviderConfig)
+ BConfig.WebConfig.Session.SessionGCMaxLifetime = AppConfig.DefaultInt64("SessionGCMaxLifetime", BConfig.WebConfig.Session.SessionGCMaxLifetime)
+ BConfig.WebConfig.Session.SessionCookieLifeTime = AppConfig.DefaultInt("SessionCookieLifeTime", BConfig.WebConfig.Session.SessionCookieLifeTime)
+ BConfig.WebConfig.Session.SessionAutoSetCookie = AppConfig.DefaultBool("SessionAutoSetCookie", BConfig.WebConfig.Session.SessionAutoSetCookie)
+ BConfig.WebConfig.Session.SessionDomain = AppConfig.DefaultString("SessionDomain", BConfig.WebConfig.Session.SessionDomain)
+
+ if sd := AppConfig.String("StaticDir"); sd != "" {
+ for k := range BConfig.WebConfig.StaticDir {
+ delete(BConfig.WebConfig.StaticDir, k)
+ }
+ sds := strings.Fields(sd)
+ for _, v := range sds {
+ if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
+ BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
+ } else {
+ BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
+ }
+ }
+ }
+
+ if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" {
+ extensions := strings.Split(sgz, ",")
+ fileExts := []string{}
+ for _, ext := range extensions {
+ ext = strings.TrimSpace(ext)
+ if ext == "" {
+ continue
+ }
+ if !strings.HasPrefix(ext, ".") {
+ ext = "." + ext
+ }
+ fileExts = append(fileExts, ext)
+ }
+ if len(fileExts) > 0 {
+ BConfig.WebConfig.StaticExtensionsToGzip = fileExts
+ }
+ }
+ return nil
+}
+
type beegoAppConfig struct {
- innerConfig config.ConfigContainer
+ innerConfig config.Configer
}
func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, error) {
@@ -93,84 +285,98 @@ func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, err
if err != nil {
return nil, err
}
- rac := &beegoAppConfig{ac}
- return rac, nil
+ return &beegoAppConfig{ac}, nil
}
func (b *beegoAppConfig) Set(key, val string) error {
+ if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil {
+ return err
+ }
return b.innerConfig.Set(key, val)
}
func (b *beegoAppConfig) String(key string) string {
- v := b.innerConfig.String(RunMode + "::" + key)
- if v == "" {
- return b.innerConfig.String(key)
+ if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" {
+ return v
}
- return v
+ return b.innerConfig.String(key)
}
func (b *beegoAppConfig) Strings(key string) []string {
- v := b.innerConfig.Strings(RunMode + "::" + key)
- if v[0] == "" {
- return b.innerConfig.Strings(key)
+ if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); v[0] != "" {
+ return v
}
- return v
+ return b.innerConfig.Strings(key)
}
func (b *beegoAppConfig) Int(key string) (int, error) {
- v, err := b.innerConfig.Int(RunMode + "::" + key)
- if err != nil {
- return b.innerConfig.Int(key)
+ if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil {
+ return v, nil
}
- return v, nil
+ return b.innerConfig.Int(key)
}
func (b *beegoAppConfig) Int64(key string) (int64, error) {
- v, err := b.innerConfig.Int64(RunMode + "::" + key)
- if err != nil {
- return b.innerConfig.Int64(key)
+ if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil {
+ return v, nil
}
- return v, nil
+ return b.innerConfig.Int64(key)
}
func (b *beegoAppConfig) Bool(key string) (bool, error) {
- v, err := b.innerConfig.Bool(RunMode + "::" + key)
- if err != nil {
- return b.innerConfig.Bool(key)
+ if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil {
+ return v, nil
}
- return v, nil
+ return b.innerConfig.Bool(key)
}
func (b *beegoAppConfig) Float(key string) (float64, error) {
- v, err := b.innerConfig.Float(RunMode + "::" + key)
- if err != nil {
- return b.innerConfig.Float(key)
+ if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil {
+ return v, nil
}
- return v, nil
+ return b.innerConfig.Float(key)
}
func (b *beegoAppConfig) DefaultString(key string, defaultval string) string {
- return b.innerConfig.DefaultString(key, defaultval)
+ if v := b.String(key); v != "" {
+ return v
+ }
+ return defaultval
}
func (b *beegoAppConfig) DefaultStrings(key string, defaultval []string) []string {
- return b.innerConfig.DefaultStrings(key, defaultval)
+ if v := b.Strings(key); len(v) != 0 {
+ return v
+ }
+ return defaultval
}
func (b *beegoAppConfig) DefaultInt(key string, defaultval int) int {
- return b.innerConfig.DefaultInt(key, defaultval)
+ if v, err := b.Int(key); err == nil {
+ return v
+ }
+ return defaultval
}
func (b *beegoAppConfig) DefaultInt64(key string, defaultval int64) int64 {
- return b.innerConfig.DefaultInt64(key, defaultval)
+ if v, err := b.Int64(key); err == nil {
+ return v
+ }
+ return defaultval
}
func (b *beegoAppConfig) DefaultBool(key string, defaultval bool) bool {
- return b.innerConfig.DefaultBool(key, defaultval)
+ if v, err := b.Bool(key); err == nil {
+ return v
+ }
+ return defaultval
}
func (b *beegoAppConfig) DefaultFloat(key string, defaultval float64) float64 {
- return b.innerConfig.DefaultFloat(key, defaultval)
+ if v, err := b.Float(key); err == nil {
+ return v
+ }
+ return defaultval
}
func (b *beegoAppConfig) DIY(key string) (interface{}, error) {
@@ -184,302 +390,3 @@ func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) {
func (b *beegoAppConfig) SaveConfigFile(filename string) error {
return b.innerConfig.SaveConfigFile(filename)
}
-
-func init() {
- // create beego application
- BeeApp = NewApp()
-
- workPath, _ = os.Getwd()
- workPath, _ = filepath.Abs(workPath)
- // initialize default configurations
- AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
-
- AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
-
- if workPath != AppPath {
- if utils.FileExists(AppConfigPath) {
- os.Chdir(AppPath)
- } else {
- AppConfigPath = filepath.Join(workPath, "conf", "app.conf")
- }
- }
-
- AppConfigProvider = "ini"
-
- StaticDir = make(map[string]string)
- StaticDir["/static"] = "static"
-
- StaticExtensionsToGzip = []string{".css", ".js"}
-
- TemplateCache = make(map[string]*template.Template)
-
- // set this to 0.0.0.0 to make this app available to externally
- EnableHttpListen = true //default enable http Listen
-
- HttpAddr = ""
- HttpPort = 8080
-
- HttpsPort = 10443
-
- AppName = "beego"
-
- RunMode = "dev" //default runmod
-
- AutoRender = true
-
- RecoverPanic = true
-
- ViewsPath = "views"
-
- SessionOn = false
- SessionProvider = "memory"
- SessionName = "beegosessionID"
- SessionGCMaxLifetime = 3600
- SessionSavePath = ""
- SessionCookieLifeTime = 0 //set cookie default is the brower life
- SessionAutoSetCookie = true
-
- UseFcgi = false
- UseStdIo = false
-
- MaxMemory = 1 << 26 //64MB
-
- EnableGzip = false
-
- HttpServerTimeOut = 0
-
- ErrorsShow = true
-
- XSRFKEY = "beegoxsrf"
- XSRFExpire = 0
-
- TemplateLeft = "{{"
- TemplateRight = "}}"
-
- BeegoServerName = "beegoServer:" + VERSION
-
- EnableAdmin = false
- AdminHttpAddr = "127.0.0.1"
- AdminHttpPort = 8088
-
- FlashName = "BEEGO_FLASH"
- FlashSeperator = "BEEGOFLASH"
-
- RouterCaseSensitive = true
-
- runtime.GOMAXPROCS(runtime.NumCPU())
-
- // init BeeLogger
- BeeLogger = logs.NewLogger(10000)
- err := BeeLogger.SetLogger("console", "")
- if err != nil {
- fmt.Println("init console log error:", err)
- }
- SetLogFuncCall(true)
-
- err = ParseConfig()
- if err != nil && os.IsNotExist(err) {
- // for init if doesn't have app.conf will not panic
- ac := config.NewFakeConfig()
- AppConfig = &beegoAppConfig{ac}
- Warning(err)
- }
-}
-
-// ParseConfig parsed default config file.
-// now only support ini, next will support json.
-func ParseConfig() (err error) {
- AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath)
- if err != nil {
- return err
- }
- envRunMode := os.Getenv("BEEGO_RUNMODE")
- // set the runmode first
- if envRunMode != "" {
- RunMode = envRunMode
- } else if runmode := AppConfig.String("RunMode"); runmode != "" {
- RunMode = runmode
- }
-
- HttpAddr = AppConfig.String("HttpAddr")
-
- if v, err := AppConfig.Int("HttpPort"); err == nil {
- HttpPort = v
- }
-
- if v, err := AppConfig.Bool("ListenTCP4"); err == nil {
- ListenTCP4 = v
- }
-
- if v, err := AppConfig.Bool("EnableHttpListen"); err == nil {
- EnableHttpListen = v
- }
-
- if maxmemory, err := AppConfig.Int64("MaxMemory"); err == nil {
- MaxMemory = maxmemory
- }
-
- if appname := AppConfig.String("AppName"); appname != "" {
- AppName = appname
- }
-
- if autorender, err := AppConfig.Bool("AutoRender"); err == nil {
- AutoRender = autorender
- }
-
- if autorecover, err := AppConfig.Bool("RecoverPanic"); err == nil {
- RecoverPanic = autorecover
- }
-
- if views := AppConfig.String("ViewsPath"); views != "" {
- ViewsPath = views
- }
-
- if sessionon, err := AppConfig.Bool("SessionOn"); err == nil {
- SessionOn = sessionon
- }
-
- if sessProvider := AppConfig.String("SessionProvider"); sessProvider != "" {
- SessionProvider = sessProvider
- }
-
- if sessName := AppConfig.String("SessionName"); sessName != "" {
- SessionName = sessName
- }
-
- if sesssavepath := AppConfig.String("SessionSavePath"); sesssavepath != "" {
- SessionSavePath = sesssavepath
- }
-
- if sessMaxLifeTime, err := AppConfig.Int64("SessionGCMaxLifetime"); err == nil && sessMaxLifeTime != 0 {
- SessionGCMaxLifetime = sessMaxLifeTime
- }
-
- if sesscookielifetime, err := AppConfig.Int("SessionCookieLifeTime"); err == nil && sesscookielifetime != 0 {
- SessionCookieLifeTime = sesscookielifetime
- }
-
- if usefcgi, err := AppConfig.Bool("UseFcgi"); err == nil {
- UseFcgi = usefcgi
- }
-
- if enablegzip, err := AppConfig.Bool("EnableGzip"); err == nil {
- EnableGzip = enablegzip
- }
-
- if directoryindex, err := AppConfig.Bool("DirectoryIndex"); err == nil {
- DirectoryIndex = directoryindex
- }
-
- if timeout, err := AppConfig.Int64("HttpServerTimeOut"); err == nil {
- HttpServerTimeOut = timeout
- }
-
- if errorsshow, err := AppConfig.Bool("ErrorsShow"); err == nil {
- ErrorsShow = errorsshow
- }
-
- if copyrequestbody, err := AppConfig.Bool("CopyRequestBody"); err == nil {
- CopyRequestBody = copyrequestbody
- }
-
- if xsrfkey := AppConfig.String("XSRFKEY"); xsrfkey != "" {
- XSRFKEY = xsrfkey
- }
-
- if enablexsrf, err := AppConfig.Bool("EnableXSRF"); err == nil {
- EnableXSRF = enablexsrf
- }
-
- if expire, err := AppConfig.Int("XSRFExpire"); err == nil {
- XSRFExpire = expire
- }
-
- if tplleft := AppConfig.String("TemplateLeft"); tplleft != "" {
- TemplateLeft = tplleft
- }
-
- if tplright := AppConfig.String("TemplateRight"); tplright != "" {
- TemplateRight = tplright
- }
-
- if httptls, err := AppConfig.Bool("EnableHttpTLS"); err == nil {
- EnableHttpTLS = httptls
- }
-
- if httpsport, err := AppConfig.Int("HttpsPort"); err == nil {
- HttpsPort = httpsport
- }
-
- if certfile := AppConfig.String("HttpCertFile"); certfile != "" {
- HttpCertFile = certfile
- }
-
- if keyfile := AppConfig.String("HttpKeyFile"); keyfile != "" {
- HttpKeyFile = keyfile
- }
-
- if serverName := AppConfig.String("BeegoServerName"); serverName != "" {
- BeegoServerName = serverName
- }
-
- if flashname := AppConfig.String("FlashName"); flashname != "" {
- FlashName = flashname
- }
-
- if flashseperator := AppConfig.String("FlashSeperator"); flashseperator != "" {
- FlashSeperator = flashseperator
- }
-
- if sd := AppConfig.String("StaticDir"); sd != "" {
- for k := range StaticDir {
- delete(StaticDir, k)
- }
- sds := strings.Fields(sd)
- for _, v := range sds {
- if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
- StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
- } else {
- StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
- }
- }
- }
-
- if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" {
- extensions := strings.Split(sgz, ",")
- if len(extensions) > 0 {
- StaticExtensionsToGzip = []string{}
- for _, ext := range extensions {
- if len(ext) == 0 {
- continue
- }
- extWithDot := ext
- if extWithDot[:1] != "." {
- extWithDot = "." + extWithDot
- }
- StaticExtensionsToGzip = append(StaticExtensionsToGzip, extWithDot)
- }
- }
- }
-
- if enableadmin, err := AppConfig.Bool("EnableAdmin"); err == nil {
- EnableAdmin = enableadmin
- }
-
- if adminhttpaddr := AppConfig.String("AdminHttpAddr"); adminhttpaddr != "" {
- AdminHttpAddr = adminhttpaddr
- }
-
- if adminhttpport, err := AppConfig.Int("AdminHttpPort"); err == nil {
- AdminHttpPort = adminhttpport
- }
-
- if enabledocs, err := AppConfig.Bool("EnableDocs"); err == nil {
- EnableDocs = enabledocs
- }
-
- if casesensitive, err := AppConfig.Bool("RouterCaseSensitive"); err == nil {
- RouterCaseSensitive = casesensitive
- }
- return nil
-}
diff --git a/config/config.go b/config/config.go
index 8d9261b8..da5d358b 100644
--- a/config/config.go
+++ b/config/config.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package config is used to parse config
// Usage:
// import(
// "github.com/astaxie/beego/config"
@@ -28,12 +29,12 @@
// cnf.Int64(key string) (int64, error)
// cnf.Bool(key string) (bool, error)
// cnf.Float(key string) (float64, error)
-// cnf.DefaultString(key string, defaultval string) string
-// cnf.DefaultStrings(key string, defaultval []string) []string
-// cnf.DefaultInt(key string, defaultval int) int
-// cnf.DefaultInt64(key string, defaultval int64) int64
-// cnf.DefaultBool(key string, defaultval bool) bool
-// cnf.DefaultFloat(key string, defaultval float64) float64
+// cnf.DefaultString(key string, defaultVal string) string
+// cnf.DefaultStrings(key string, defaultVal []string) []string
+// cnf.DefaultInt(key string, defaultVal int) int
+// cnf.DefaultInt64(key string, defaultVal int64) int64
+// cnf.DefaultBool(key string, defaultVal bool) bool
+// cnf.DefaultFloat(key string, defaultVal float64) float64
// cnf.DIY(key string) (interface{}, error)
// cnf.GetSection(section string) (map[string]string, error)
// cnf.SaveConfigFile(filename string) error
@@ -45,30 +46,30 @@ import (
"fmt"
)
-// ConfigContainer defines how to get and set value from configuration raw data.
-type ConfigContainer interface {
- Set(key, val string) error // support section::key type in given key when using ini type.
- String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
+// Configer defines how to get and set value from configuration raw data.
+type Configer interface {
+ Set(key, val string) error //support section::key type in given key when using ini type.
+ String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Strings(key string) []string //get string slice
Int(key string) (int, error)
Int64(key string) (int64, error)
Bool(key string) (bool, error)
Float(key string) (float64, error)
- DefaultString(key string, defaultval string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
- DefaultStrings(key string, defaultval []string) []string //get string slice
- DefaultInt(key string, defaultval int) int
- DefaultInt64(key string, defaultval int64) int64
- DefaultBool(key string, defaultval bool) bool
- DefaultFloat(key string, defaultval float64) float64
+ DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
+ DefaultStrings(key string, defaultVal []string) []string //get string slice
+ DefaultInt(key string, defaultVal int) int
+ DefaultInt64(key string, defaultVal int64) int64
+ DefaultBool(key string, defaultVal bool) bool
+ DefaultFloat(key string, defaultVal float64) float64
DIY(key string) (interface{}, error)
GetSection(section string) (map[string]string, error)
SaveConfigFile(filename string) error
}
-// Config is the adapter interface for parsing config file to get raw data to ConfigContainer.
+// Config is the adapter interface for parsing config file to get raw data to Configer.
type Config interface {
- Parse(key string) (ConfigContainer, error)
- ParseData(data []byte) (ConfigContainer, error)
+ Parse(key string) (Configer, error)
+ ParseData(data []byte) (Configer, error)
}
var adapters = make(map[string]Config)
@@ -86,19 +87,19 @@ func Register(name string, adapter Config) {
adapters[name] = adapter
}
-// adapterName is ini/json/xml/yaml.
+// NewConfig adapterName is ini/json/xml/yaml.
// filename is the config file path.
-func NewConfig(adapterName, fileaname string) (ConfigContainer, error) {
+func NewConfig(adapterName, filename string) (Configer, error) {
adapter, ok := adapters[adapterName]
if !ok {
return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName)
}
- return adapter.Parse(fileaname)
+ return adapter.Parse(filename)
}
-// adapterName is ini/json/xml/yaml.
+// NewConfigData adapterName is ini/json/xml/yaml.
// data is the config data.
-func NewConfigData(adapterName string, data []byte) (ConfigContainer, error) {
+func NewConfigData(adapterName string, data []byte) (Configer, error) {
adapter, ok := adapters[adapterName]
if !ok {
return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName)
diff --git a/config/fake.go b/config/fake.go
index 54588e5e..50aa5d4a 100644
--- a/config/fake.go
+++ b/config/fake.go
@@ -38,11 +38,11 @@ func (c *fakeConfigContainer) String(key string) string {
}
func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string {
- if v := c.getData(key); v == "" {
+ v := c.getData(key)
+ if v == "" {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) Strings(key string) []string {
@@ -50,11 +50,11 @@ func (c *fakeConfigContainer) Strings(key string) []string {
}
func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string {
- if v := c.Strings(key); len(v) == 0 {
+ v := c.Strings(key)
+ if len(v) == 0 {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) Int(key string) (int, error) {
@@ -62,11 +62,11 @@ func (c *fakeConfigContainer) Int(key string) (int, error) {
}
func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int {
- if v, err := c.Int(key); err != nil {
+ v, err := c.Int(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
@@ -74,11 +74,11 @@ func (c *fakeConfigContainer) Int64(key string) (int64, error) {
}
func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
- if v, err := c.Int64(key); err != nil {
+ v, err := c.Int64(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
@@ -86,11 +86,11 @@ func (c *fakeConfigContainer) Bool(key string) (bool, error) {
}
func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool {
- if v, err := c.Bool(key); err != nil {
+ v, err := c.Bool(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) Float(key string) (float64, error) {
@@ -98,11 +98,11 @@ func (c *fakeConfigContainer) Float(key string) (float64, error) {
}
func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
- if v, err := c.Float(key); err != nil {
+ v, err := c.Float(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
@@ -120,9 +120,10 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error {
return errors.New("not implement in the fakeConfigContainer")
}
-var _ ConfigContainer = new(fakeConfigContainer)
+var _ Configer = new(fakeConfigContainer)
-func NewFakeConfig() ConfigContainer {
+// NewFakeConfig return a fake Congiger
+func NewFakeConfig() Configer {
return &fakeConfigContainer{
data: make(map[string]string),
}
diff --git a/config/ini.go b/config/ini.go
index 837c9ffe..59e84e1e 100644
--- a/config/ini.go
+++ b/config/ini.go
@@ -31,23 +31,23 @@ import (
)
var (
- DEFAULT_SECTION = "default" // default section means if some ini items not in a section, make them in default section,
- bNumComment = []byte{'#'} // number signal
- bSemComment = []byte{';'} // semicolon signal
- bEmpty = []byte{}
- bEqual = []byte{'='} // equal signal
- bDQuote = []byte{'"'} // quote signal
- sectionStart = []byte{'['} // section start signal
- sectionEnd = []byte{']'} // section end signal
- lineBreak = "\n"
+ defaultSection = "default" // default section means if some ini items not in a section, make them in default section,
+ bNumComment = []byte{'#'} // number signal
+ bSemComment = []byte{';'} // semicolon signal
+ bEmpty = []byte{}
+ bEqual = []byte{'='} // equal signal
+ bDQuote = []byte{'"'} // quote signal
+ sectionStart = []byte{'['} // section start signal
+ sectionEnd = []byte{']'} // section end signal
+ lineBreak = "\n"
)
// IniConfig implements Config to parse ini file.
type IniConfig struct {
}
-// ParseFile creates a new Config and parses the file configuration from the named file.
-func (ini *IniConfig) Parse(name string) (ConfigContainer, error) {
+// Parse creates a new Config and parses the file configuration from the named file.
+func (ini *IniConfig) Parse(name string) (Configer, error) {
return ini.parseFile(name)
}
@@ -77,7 +77,7 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
buf.ReadByte()
}
}
- section := DEFAULT_SECTION
+ section := defaultSection
for {
line, _, err := buf.ReadLine()
if err == io.EOF {
@@ -171,7 +171,8 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
return cfg, nil
}
-func (ini *IniConfig) ParseData(data []byte) (ConfigContainer, error) {
+// ParseData parse ini the data
+func (ini *IniConfig) ParseData(data []byte) (Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
@@ -181,7 +182,7 @@ func (ini *IniConfig) ParseData(data []byte) (ConfigContainer, error) {
return ini.Parse(tmpName)
}
-// A Config represents the ini configuration.
+// IniConfigContainer A Config represents the ini configuration.
// When set and get value, support key as section:name type.
type IniConfigContainer struct {
filename string
@@ -199,11 +200,11 @@ func (c *IniConfigContainer) Bool(key string) (bool, error) {
// DefaultBool returns the boolean value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool {
- if v, err := c.Bool(key); err != nil {
+ v, err := c.Bool(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int returns the integer value for a given key.
@@ -214,11 +215,11 @@ func (c *IniConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int {
- if v, err := c.Int(key); err != nil {
+ v, err := c.Int(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int64 returns the int64 value for a given key.
@@ -229,11 +230,11 @@ func (c *IniConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
- if v, err := c.Int64(key); err != nil {
+ v, err := c.Int64(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Float returns the float value for a given key.
@@ -244,11 +245,11 @@ func (c *IniConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
- if v, err := c.Float(key); err != nil {
+ v, err := c.Float(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// String returns the string value for a given key.
@@ -259,11 +260,11 @@ func (c *IniConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultString(key string, defaultval string) string {
- if v := c.String(key); v == "" {
+ v := c.String(key)
+ if v == "" {
return defaultval
- } else {
- return v
}
+ return v
}
// Strings returns the []string value for a given key.
@@ -274,20 +275,19 @@ func (c *IniConfigContainer) Strings(key string) []string {
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string {
- if v := c.Strings(key); len(v) == 0 {
+ v := c.Strings(key)
+ if len(v) == 0 {
return defaultval
- } else {
- return v
}
+ return v
}
// GetSection returns map for the given section
func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v, nil
- } else {
- return nil, errors.New("not exist setction")
}
+ return nil, errors.New("not exist setction")
}
// SaveConfigFile save the config into file
@@ -300,21 +300,8 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
defer f.Close()
buf := bytes.NewBuffer(nil)
- for section, dt := range c.data {
- // Write section comments.
- if v, ok := c.sectionComment[section]; ok {
- if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil {
- return err
- }
- }
-
- if section != DEFAULT_SECTION {
- // Write section name.
- if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil {
- return err
- }
- }
-
+ // Save default section at first place
+ if dt, ok := c.data[defaultSection]; ok {
for key, val := range dt {
if key != " " {
// Write key comments.
@@ -336,6 +323,43 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
return err
}
}
+ // Save named sections
+ for section, dt := range c.data {
+ if section != defaultSection {
+ // Write section comments.
+ if v, ok := c.sectionComment[section]; ok {
+ if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil {
+ return err
+ }
+ }
+
+ // Write section name.
+ if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil {
+ return err
+ }
+
+ for key, val := range dt {
+ if key != " " {
+ // Write key comments.
+ if v, ok := c.keyComment[key]; ok {
+ if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil {
+ return err
+ }
+ }
+
+ // Write key and value.
+ if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil {
+ return err
+ }
+ }
+ }
+
+ // Put a line between sections.
+ if _, err = buf.WriteString(lineBreak); err != nil {
+ return err
+ }
+ }
+ }
if _, err = buf.WriteTo(f); err != nil {
return err
@@ -343,7 +367,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
return nil
}
-// WriteValue writes a new value for key.
+// Set writes a new value for key.
// if write to one section, the key need be "section::key".
// if the section is not existed, it panics.
func (c *IniConfigContainer) Set(key, value string) error {
@@ -355,14 +379,14 @@ func (c *IniConfigContainer) Set(key, value string) error {
var (
section, k string
- sectionKey []string = strings.Split(key, "::")
+ sectionKey = strings.Split(key, "::")
)
if len(sectionKey) >= 2 {
section = sectionKey[0]
k = sectionKey[1]
} else {
- section = DEFAULT_SECTION
+ section = defaultSection
k = sectionKey[0]
}
@@ -391,13 +415,13 @@ func (c *IniConfigContainer) getdata(key string) string {
var (
section, k string
- sectionKey []string = strings.Split(strings.ToLower(key), "::")
+ sectionKey = strings.Split(strings.ToLower(key), "::")
)
if len(sectionKey) >= 2 {
section = sectionKey[0]
k = sectionKey[1]
} else {
- section = DEFAULT_SECTION
+ section = defaultSection
k = sectionKey[0]
}
if v, ok := c.data[section]; ok {
diff --git a/config/json.go b/config/json.go
index e2b53793..65b4ac48 100644
--- a/config/json.go
+++ b/config/json.go
@@ -23,12 +23,12 @@ import (
"sync"
)
-// JsonConfig is a json config parser and implements Config interface.
-type JsonConfig struct {
+// JSONConfig is a json config parser and implements Config interface.
+type JSONConfig struct {
}
// Parse returns a ConfigContainer with parsed json config map.
-func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
+func (js *JSONConfig) Parse(filename string) (Configer, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
@@ -43,8 +43,8 @@ func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
}
// ParseData returns a ConfigContainer with json string
-func (js *JsonConfig) ParseData(data []byte) (ConfigContainer, error) {
- x := &JsonConfigContainer{
+func (js *JSONConfig) ParseData(data []byte) (Configer, error) {
+ x := &JSONConfigContainer{
data: make(map[string]interface{}),
}
err := json.Unmarshal(data, &x.data)
@@ -59,15 +59,15 @@ func (js *JsonConfig) ParseData(data []byte) (ConfigContainer, error) {
return x, nil
}
-// A Config represents the json configuration.
+// JSONConfigContainer A Config represents the json configuration.
// Only when get value, support key as section:name type.
-type JsonConfigContainer struct {
+type JSONConfigContainer struct {
data map[string]interface{}
sync.RWMutex
}
// Bool returns the boolean value for a given key.
-func (c *JsonConfigContainer) Bool(key string) (bool, error) {
+func (c *JSONConfigContainer) Bool(key string) (bool, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(bool); ok {
@@ -80,7 +80,7 @@ func (c *JsonConfigContainer) Bool(key string) (bool, error) {
// DefaultBool return the bool value if has no error
// otherwise return the defaultval
-func (c *JsonConfigContainer) DefaultBool(key string, defaultval bool) bool {
+func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err == nil {
return v
}
@@ -88,7 +88,7 @@ func (c *JsonConfigContainer) DefaultBool(key string, defaultval bool) bool {
}
// Int returns the integer value for a given key.
-func (c *JsonConfigContainer) Int(key string) (int, error) {
+func (c *JSONConfigContainer) Int(key string) (int, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
@@ -101,7 +101,7 @@ func (c *JsonConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
-func (c *JsonConfigContainer) DefaultInt(key string, defaultval int) int {
+func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err == nil {
return v
}
@@ -109,7 +109,7 @@ func (c *JsonConfigContainer) DefaultInt(key string, defaultval int) int {
}
// Int64 returns the int64 value for a given key.
-func (c *JsonConfigContainer) Int64(key string) (int64, error) {
+func (c *JSONConfigContainer) Int64(key string) (int64, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
@@ -122,7 +122,7 @@ func (c *JsonConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
-func (c *JsonConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
+func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err == nil {
return v
}
@@ -130,7 +130,7 @@ func (c *JsonConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
}
// Float returns the float value for a given key.
-func (c *JsonConfigContainer) Float(key string) (float64, error) {
+func (c *JSONConfigContainer) Float(key string) (float64, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
@@ -143,7 +143,7 @@ func (c *JsonConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
-func (c *JsonConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
+func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err == nil {
return v
}
@@ -151,7 +151,7 @@ func (c *JsonConfigContainer) DefaultFloat(key string, defaultval float64) float
}
// String returns the string value for a given key.
-func (c *JsonConfigContainer) String(key string) string {
+func (c *JSONConfigContainer) String(key string) string {
val := c.getData(key)
if val != nil {
if v, ok := val.(string); ok {
@@ -163,8 +163,8 @@ func (c *JsonConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
-func (c *JsonConfigContainer) DefaultString(key string, defaultval string) string {
- // TODO FIXME should not use "" to replace non existance
+func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string {
+ // TODO FIXME should not use "" to replace non existence
if v := c.String(key); v != "" {
return v
}
@@ -172,7 +172,7 @@ func (c *JsonConfigContainer) DefaultString(key string, defaultval string) strin
}
// Strings returns the []string value for a given key.
-func (c *JsonConfigContainer) Strings(key string) []string {
+func (c *JSONConfigContainer) Strings(key string) []string {
stringVal := c.String(key)
if stringVal == "" {
return []string{}
@@ -182,7 +182,7 @@ func (c *JsonConfigContainer) Strings(key string) []string {
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
-func (c *JsonConfigContainer) DefaultStrings(key string, defaultval []string) []string {
+func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) > 0 {
return v
}
@@ -190,7 +190,7 @@ func (c *JsonConfigContainer) DefaultStrings(key string, defaultval []string) []
}
// GetSection returns map for the given section
-func (c *JsonConfigContainer) GetSection(section string) (map[string]string, error) {
+func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v.(map[string]string), nil
}
@@ -198,7 +198,7 @@ func (c *JsonConfigContainer) GetSection(section string) (map[string]string, err
}
// SaveConfigFile save the config into file
-func (c *JsonConfigContainer) SaveConfigFile(filename string) (err error) {
+func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
@@ -214,7 +214,7 @@ func (c *JsonConfigContainer) SaveConfigFile(filename string) (err error) {
}
// Set writes a new value for key.
-func (c *JsonConfigContainer) Set(key, val string) error {
+func (c *JSONConfigContainer) Set(key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
@@ -222,7 +222,7 @@ func (c *JsonConfigContainer) Set(key, val string) error {
}
// DIY returns the raw value by a given key.
-func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
+func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) {
val := c.getData(key)
if val != nil {
return val, nil
@@ -231,7 +231,7 @@ func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
}
// section.key or key
-func (c *JsonConfigContainer) getData(key string) interface{} {
+func (c *JSONConfigContainer) getData(key string) interface{} {
if len(key) == 0 {
return nil
}
@@ -261,5 +261,5 @@ func (c *JsonConfigContainer) getData(key string) interface{} {
}
func init() {
- Register("json", &JsonConfig{})
+ Register("json", &JSONConfig{})
}
diff --git a/config/xml/xml.go b/config/xml/xml.go
index a1d9fcdb..4d48f7d2 100644
--- a/config/xml/xml.go
+++ b/config/xml/xml.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package xml for config provider
+// Package xml for config provider
//
// depend on github.com/beego/x2j
//
@@ -45,20 +45,20 @@ import (
"github.com/beego/x2j"
)
-// XmlConfig is a xml config parser and implements Config interface.
+// Config is a xml config parser and implements Config interface.
// xml configurations should be included in tag.
// only support key/value pair as value as each item.
-type XMLConfig struct{}
+type Config struct{}
// Parse returns a ConfigContainer with parsed xml config map.
-func (xc *XMLConfig) Parse(filename string) (config.ConfigContainer, error) {
+func (xc *Config) Parse(filename string) (config.Configer, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()
- x := &XMLConfigContainer{data: make(map[string]interface{})}
+ x := &ConfigContainer{data: make(map[string]interface{})}
content, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
@@ -73,84 +73,86 @@ func (xc *XMLConfig) Parse(filename string) (config.ConfigContainer, error) {
return x, nil
}
-func (x *XMLConfig) ParseData(data []byte) (config.ConfigContainer, error) {
+// ParseData xml data
+func (xc *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
return nil, err
}
- return x.Parse(tmpName)
+ return xc.Parse(tmpName)
}
-// A Config represents the xml configuration.
-type XMLConfigContainer struct {
+// ConfigContainer A Config represents the xml configuration.
+type ConfigContainer struct {
data map[string]interface{}
sync.Mutex
}
// Bool returns the boolean value for a given key.
-func (c *XMLConfigContainer) Bool(key string) (bool, error) {
+func (c *ConfigContainer) Bool(key string) (bool, error) {
return strconv.ParseBool(c.data[key].(string))
}
// DefaultBool return the bool value if has no error
// otherwise return the defaultval
-func (c *XMLConfigContainer) DefaultBool(key string, defaultval bool) bool {
- if v, err := c.Bool(key); err != nil {
+func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
+ v, err := c.Bool(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int returns the integer value for a given key.
-func (c *XMLConfigContainer) Int(key string) (int, error) {
+func (c *ConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.data[key].(string))
}
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
-func (c *XMLConfigContainer) DefaultInt(key string, defaultval int) int {
- if v, err := c.Int(key); err != nil {
+func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
+ v, err := c.Int(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int64 returns the int64 value for a given key.
-func (c *XMLConfigContainer) Int64(key string) (int64, error) {
+func (c *ConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.data[key].(string), 10, 64)
}
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
-func (c *XMLConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
- if v, err := c.Int64(key); err != nil {
+func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
+ v, err := c.Int64(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
+
}
// Float returns the float value for a given key.
-func (c *XMLConfigContainer) Float(key string) (float64, error) {
+func (c *ConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.data[key].(string), 64)
}
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
-func (c *XMLConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
- if v, err := c.Float(key); err != nil {
+func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
+ v, err := c.Float(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// String returns the string value for a given key.
-func (c *XMLConfigContainer) String(key string) string {
+func (c *ConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok {
return v
}
@@ -159,40 +161,39 @@ func (c *XMLConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
-func (c *XMLConfigContainer) DefaultString(key string, defaultval string) string {
- if v := c.String(key); v == "" {
+func (c *ConfigContainer) DefaultString(key string, defaultval string) string {
+ v := c.String(key)
+ if v == "" {
return defaultval
- } else {
- return v
}
+ return v
}
// Strings returns the []string value for a given key.
-func (c *XMLConfigContainer) Strings(key string) []string {
+func (c *ConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
-func (c *XMLConfigContainer) DefaultStrings(key string, defaultval []string) []string {
- if v := c.Strings(key); len(v) == 0 {
+func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string {
+ v := c.Strings(key)
+ if len(v) == 0 {
return defaultval
- } else {
- return v
}
+ return v
}
// GetSection returns map for the given section
-func (c *XMLConfigContainer) GetSection(section string) (map[string]string, error) {
+func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v.(map[string]string), nil
- } else {
- return nil, errors.New("not exist setction")
}
+ return nil, errors.New("not exist setction")
}
// SaveConfigFile save the config into file
-func (c *XMLConfigContainer) SaveConfigFile(filename string) (err error) {
+func (c *ConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
@@ -207,8 +208,8 @@ func (c *XMLConfigContainer) SaveConfigFile(filename string) (err error) {
return err
}
-// WriteValue writes a new value for key.
-func (c *XMLConfigContainer) Set(key, val string) error {
+// Set writes a new value for key.
+func (c *ConfigContainer) Set(key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
@@ -216,7 +217,7 @@ func (c *XMLConfigContainer) Set(key, val string) error {
}
// DIY returns the raw value by a given key.
-func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
+func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok {
return v, nil
}
@@ -224,5 +225,5 @@ func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
}
func init() {
- config.Register("xml", &XMLConfig{})
+ config.Register("xml", &Config{})
}
diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go
index c5be44a9..f034d3ba 100644
--- a/config/yaml/yaml.go
+++ b/config/yaml/yaml.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package yaml for config provider
+// Package yaml for config provider
//
// depend on github.com/beego/goyaml2
//
@@ -46,22 +46,23 @@ import (
"github.com/beego/goyaml2"
)
-// YAMLConfig is a yaml config parser and implements Config interface.
-type YAMLConfig struct{}
+// Config is a yaml config parser and implements Config interface.
+type Config struct{}
// Parse returns a ConfigContainer with parsed yaml config map.
-func (yaml *YAMLConfig) Parse(filename string) (y config.ConfigContainer, err error) {
+func (yaml *Config) Parse(filename string) (y config.Configer, err error) {
cnf, err := ReadYmlReader(filename)
if err != nil {
return
}
- y = &YAMLConfigContainer{
+ y = &ConfigContainer{
data: cnf,
}
return
}
-func (yaml *YAMLConfig) ParseData(data []byte) (config.ConfigContainer, error) {
+// ParseData parse yaml data
+func (yaml *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
@@ -71,7 +72,7 @@ func (yaml *YAMLConfig) ParseData(data []byte) (config.ConfigContainer, error) {
return yaml.Parse(tmpName)
}
-// Read yaml file to map.
+// ReadYmlReader Read yaml file to map.
// if json like, use json package, unless goyaml2 package.
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
f, err := os.Open(path)
@@ -112,14 +113,14 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
return
}
-// A Config represents the yaml configuration.
-type YAMLConfigContainer struct {
+// ConfigContainer A Config represents the yaml configuration.
+type ConfigContainer struct {
data map[string]interface{}
sync.Mutex
}
// Bool returns the boolean value for a given key.
-func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
+func (c *ConfigContainer) Bool(key string) (bool, error) {
if v, ok := c.data[key].(bool); ok {
return v, nil
}
@@ -128,16 +129,16 @@ func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
// DefaultBool return the bool value if has no error
// otherwise return the defaultval
-func (c *YAMLConfigContainer) DefaultBool(key string, defaultval bool) bool {
- if v, err := c.Bool(key); err != nil {
+func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
+ v, err := c.Bool(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int returns the integer value for a given key.
-func (c *YAMLConfigContainer) Int(key string) (int, error) {
+func (c *ConfigContainer) Int(key string) (int, error) {
if v, ok := c.data[key].(int64); ok {
return int(v), nil
}
@@ -146,16 +147,16 @@ func (c *YAMLConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
-func (c *YAMLConfigContainer) DefaultInt(key string, defaultval int) int {
- if v, err := c.Int(key); err != nil {
+func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
+ v, err := c.Int(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Int64 returns the int64 value for a given key.
-func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
+func (c *ConfigContainer) Int64(key string) (int64, error) {
if v, ok := c.data[key].(int64); ok {
return v, nil
}
@@ -164,16 +165,16 @@ func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
-func (c *YAMLConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
- if v, err := c.Int64(key); err != nil {
+func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
+ v, err := c.Int64(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// Float returns the float value for a given key.
-func (c *YAMLConfigContainer) Float(key string) (float64, error) {
+func (c *ConfigContainer) Float(key string) (float64, error) {
if v, ok := c.data[key].(float64); ok {
return v, nil
}
@@ -182,16 +183,16 @@ func (c *YAMLConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
-func (c *YAMLConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
- if v, err := c.Float(key); err != nil {
+func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
+ v, err := c.Float(key)
+ if err != nil {
return defaultval
- } else {
- return v
}
+ return v
}
// String returns the string value for a given key.
-func (c *YAMLConfigContainer) String(key string) string {
+func (c *ConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok {
return v
}
@@ -200,40 +201,40 @@ func (c *YAMLConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
-func (c *YAMLConfigContainer) DefaultString(key string, defaultval string) string {
- if v := c.String(key); v == "" {
+func (c *ConfigContainer) DefaultString(key string, defaultval string) string {
+ v := c.String(key)
+ if v == "" {
return defaultval
- } else {
- return v
}
+ return v
}
// Strings returns the []string value for a given key.
-func (c *YAMLConfigContainer) Strings(key string) []string {
+func (c *ConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
-func (c *YAMLConfigContainer) DefaultStrings(key string, defaultval []string) []string {
- if v := c.Strings(key); len(v) == 0 {
+func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string {
+ v := c.Strings(key)
+ if len(v) == 0 {
return defaultval
- } else {
- return v
}
+ return v
}
// GetSection returns map for the given section
-func (c *YAMLConfigContainer) GetSection(section string) (map[string]string, error) {
- if v, ok := c.data[section]; ok {
+func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
+ v, ok := c.data[section]
+ if ok {
return v.(map[string]string), nil
- } else {
- return nil, errors.New("not exist setction")
}
+ return nil, errors.New("not exist setction")
}
// SaveConfigFile save the config into file
-func (c *YAMLConfigContainer) SaveConfigFile(filename string) (err error) {
+func (c *ConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
@@ -244,8 +245,8 @@ func (c *YAMLConfigContainer) SaveConfigFile(filename string) (err error) {
return err
}
-// WriteValue writes a new value for key.
-func (c *YAMLConfigContainer) Set(key, val string) error {
+// Set writes a new value for key.
+func (c *ConfigContainer) Set(key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
@@ -253,7 +254,7 @@ func (c *YAMLConfigContainer) Set(key, val string) error {
}
// DIY returns the raw value by a given key.
-func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
+func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok {
return v, nil
}
@@ -261,5 +262,5 @@ func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
}
func init() {
- config.Register("yaml", &YAMLConfig{})
+ config.Register("yaml", &Config{})
}
diff --git a/config_test.go b/config_test.go
index 17645f80..cf4a781d 100644
--- a/config_test.go
+++ b/config_test.go
@@ -19,11 +19,11 @@ import (
)
func TestDefaults(t *testing.T) {
- if FlashName != "BEEGO_FLASH" {
+ if BConfig.WebConfig.FlashName != "BEEGO_FLASH" {
t.Errorf("FlashName was not set to default.")
}
- if FlashSeperator != "BEEGOFLASH" {
+ if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" {
t.Errorf("FlashName was not set to default.")
}
}
diff --git a/context/acceptencoder.go b/context/acceptencoder.go
new file mode 100644
index 00000000..033d9ca8
--- /dev/null
+++ b/context/acceptencoder.go
@@ -0,0 +1,197 @@
+// Copyright 2015 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package context
+
+import (
+ "bytes"
+ "compress/flate"
+ "compress/gzip"
+ "compress/zlib"
+ "io"
+ "net/http"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+)
+
+type resetWriter interface {
+ io.Writer
+ Reset(w io.Writer)
+}
+
+type nopResetWriter struct {
+ io.Writer
+}
+
+func (n nopResetWriter) Reset(w io.Writer) {
+ //do nothing
+}
+
+type acceptEncoder struct {
+ name string
+ levelEncode func(int) resetWriter
+ bestSpeedPool *sync.Pool
+ bestCompressionPool *sync.Pool
+}
+
+func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
+ if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil {
+ return nopResetWriter{wr}
+ }
+ var rwr resetWriter
+ switch level {
+ case flate.BestSpeed:
+ rwr = ac.bestSpeedPool.Get().(resetWriter)
+ case flate.BestCompression:
+ rwr = ac.bestCompressionPool.Get().(resetWriter)
+ default:
+ rwr = ac.levelEncode(level)
+ }
+ rwr.Reset(wr)
+ return rwr
+}
+
+func (ac acceptEncoder) put(wr resetWriter, level int) {
+ if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil {
+ return
+ }
+ wr.Reset(nil)
+ switch level {
+ case flate.BestSpeed:
+ ac.bestSpeedPool.Put(wr)
+ case flate.BestCompression:
+ ac.bestCompressionPool.Put(wr)
+ }
+}
+
+var (
+ noneCompressEncoder = acceptEncoder{"", nil, nil, nil}
+ gzipCompressEncoder = acceptEncoder{"gzip",
+ func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr },
+ &sync.Pool{
+ New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestSpeed); return wr },
+ },
+ &sync.Pool{
+ New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr },
+ },
+ }
+
+ //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed
+ //deflate
+ //The "zlib" format defined in RFC 1950 [31] in combination with
+ //the "deflate" compression mechanism described in RFC 1951 [29].
+ deflateCompressEncoder = acceptEncoder{"deflate",
+ func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr },
+ &sync.Pool{
+ New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestSpeed); return wr },
+ },
+ &sync.Pool{
+ New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr },
+ },
+ }
+)
+
+var (
+ encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore
+ "gzip": gzipCompressEncoder,
+ "deflate": deflateCompressEncoder,
+ "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip
+ "identity": noneCompressEncoder, // identity means none-compress
+ }
+)
+
+// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate)
+func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) {
+ return writeLevel(encoding, writer, file, flate.BestCompression)
+}
+
+// WriteBody reads writes content to writer by the specific encoding(gzip/deflate)
+func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) {
+ return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed)
+}
+
+// writeLevel reads from reader,writes to writer by specific encoding and compress level
+// the compress level is defined by deflate package
+func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) {
+ var outputWriter resetWriter
+ var err error
+ var ce = noneCompressEncoder
+
+ if cf, ok := encoderMap[encoding]; ok {
+ ce = cf
+ }
+ encoding = ce.name
+ outputWriter = ce.encode(writer, level)
+ defer ce.put(outputWriter, level)
+
+ _, err = io.Copy(outputWriter, reader)
+ if err != nil {
+ return false, "", err
+ }
+
+ switch outputWriter.(type) {
+ case io.WriteCloser:
+ outputWriter.(io.WriteCloser).Close()
+ }
+ return encoding != "", encoding, nil
+}
+
+// ParseEncoding will extract the right encoding for response
+// the Accept-Encoding's sec is here:
+// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3
+func ParseEncoding(r *http.Request) string {
+ if r == nil {
+ return ""
+ }
+ return parseEncoding(r)
+}
+
+type q struct {
+ name string
+ value float64
+}
+
+func parseEncoding(r *http.Request) string {
+ acceptEncoding := r.Header.Get("Accept-Encoding")
+ if acceptEncoding == "" {
+ return ""
+ }
+ var lastQ q
+ for _, v := range strings.Split(acceptEncoding, ",") {
+ v = strings.TrimSpace(v)
+ if v == "" {
+ continue
+ }
+ vs := strings.Split(v, ";")
+ if len(vs) == 1 {
+ lastQ = q{vs[0], 1}
+ break
+ }
+ if len(vs) == 2 {
+ f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64)
+ if f == 0 {
+ continue
+ }
+ if f > lastQ.value {
+ lastQ = q{vs[0], f}
+ }
+ }
+ }
+ if cf, ok := encoderMap[lastQ.name]; ok {
+ return cf.name
+ }
+ return ""
+}
diff --git a/context/acceptencoder_test.go b/context/acceptencoder_test.go
new file mode 100644
index 00000000..3afff679
--- /dev/null
+++ b/context/acceptencoder_test.go
@@ -0,0 +1,44 @@
+// Copyright 2015 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package context
+
+import (
+ "net/http"
+ "testing"
+)
+
+func Test_ExtractEncoding(t *testing.T) {
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,deflate"}}}) != "gzip" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate,gzip"}}}) != "deflate" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate"}}}) != "deflate" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0,deflate"}}}) != "deflate" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"*"}}}) != "gzip" {
+ t.Fail()
+ }
+}
diff --git a/context/context.go b/context/context.go
index f6aa85d6..db790ff2 100644
--- a/context/context.go
+++ b/context/context.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package context provide the context utils
// Usage:
//
// import "github.com/astaxie/beego/context"
@@ -22,10 +23,13 @@
package context
import (
+ "bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
+ "errors"
"fmt"
+ "net"
"net/http"
"strconv"
"strings"
@@ -34,14 +38,30 @@ import (
"github.com/astaxie/beego/utils"
)
-// Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
+// NewContext return the Context with Input and Output
+func NewContext() *Context {
+ return &Context{
+ Input: NewInput(),
+ Output: NewOutput(),
+ }
+}
+
+// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
// BeegoInput and BeegoOutput provides some api to operate request and response more easily.
type Context struct {
Input *BeegoInput
Output *BeegoOutput
Request *http.Request
- ResponseWriter http.ResponseWriter
- _xsrf_token string
+ ResponseWriter *Response
+ _xsrfToken string
+}
+
+// Reset init Context, BeegoInput and BeegoOutput
+func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
+ ctx.Request = r
+ ctx.ResponseWriter = &Response{rw, false, 0}
+ ctx.Input.Reset(ctx)
+ ctx.Output.Reset(ctx)
}
// Redirect does redirection to localurl with http header status code.
@@ -54,29 +74,28 @@ func (ctx *Context) Redirect(status int, localurl string) {
// Abort stops this request.
// if beego.ErrorMaps exists, panic body.
func (ctx *Context) Abort(status int, body string) {
- ctx.ResponseWriter.WriteHeader(status)
panic(body)
}
-// Write string to response body.
+// WriteString Write string to response body.
// it sends response body.
func (ctx *Context) WriteString(content string) {
ctx.ResponseWriter.Write([]byte(content))
}
-// Get cookie from request by a given key.
+// GetCookie Get cookie from request by a given key.
// It's alias of BeegoInput.Cookie.
func (ctx *Context) GetCookie(key string) string {
return ctx.Input.Cookie(key)
}
-// Set cookie for response.
+// SetCookie Set cookie for response.
// It's alias of BeegoOutput.Cookie.
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
ctx.Output.Cookie(name, value, others...)
}
-// Get secure cookie from request by a given key.
+// GetSecureCookie Get secure cookie from request by a given key.
func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
val := ctx.Input.Cookie(key)
if val == "" {
@@ -103,7 +122,7 @@ func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
return string(res), true
}
-// Set Secure cookie for response.
+// SetSecureCookie Set Secure cookie for response.
func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) {
vs := base64.URLEncoding.EncodeToString([]byte(value))
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
@@ -114,23 +133,23 @@ func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interf
ctx.Output.Cookie(name, cookie, others...)
}
-// XsrfToken creates a xsrf token string and returns.
-func (ctx *Context) XsrfToken(key string, expire int64) string {
- if ctx._xsrf_token == "" {
+// XSRFToken creates a xsrf token string and returns.
+func (ctx *Context) XSRFToken(key string, expire int64) string {
+ if ctx._xsrfToken == "" {
token, ok := ctx.GetSecureCookie(key, "_xsrf")
if !ok {
token = string(utils.RandomCreateBytes(32))
ctx.SetSecureCookie(key, "_xsrf", token, expire)
}
- ctx._xsrf_token = token
+ ctx._xsrfToken = token
}
- return ctx._xsrf_token
+ return ctx._xsrfToken
}
-// CheckXsrfCookie checks xsrf token in this request is valid or not.
+// CheckXSRFCookie checks xsrf token in this request is valid or not.
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
// or in form field value named as "_xsrf".
-func (ctx *Context) CheckXsrfCookie() bool {
+func (ctx *Context) CheckXSRFCookie() bool {
token := ctx.Input.Query("_xsrf")
if token == "" {
token = ctx.Request.Header.Get("X-Xsrftoken")
@@ -142,9 +161,57 @@ func (ctx *Context) CheckXsrfCookie() bool {
ctx.Abort(403, "'_xsrf' argument missing from POST")
return false
}
- if ctx._xsrf_token != token {
+ if ctx._xsrfToken != token {
ctx.Abort(403, "XSRF cookie does not match POST argument")
return false
}
return true
}
+
+//Response is a wrapper for the http.ResponseWriter
+//started set to true if response was written to then don't execute other handler
+type Response struct {
+ http.ResponseWriter
+ Started bool
+ Status int
+}
+
+// Write writes the data to the connection as part of an HTTP reply,
+// and sets `started` to true.
+// started means the response has sent out.
+func (w *Response) Write(p []byte) (int, error) {
+ w.Started = true
+ return w.ResponseWriter.Write(p)
+}
+
+// WriteHeader sends an HTTP response header with status code,
+// and sets `started` to true.
+func (w *Response) WriteHeader(code int) {
+ w.Status = code
+ w.Started = true
+ w.ResponseWriter.WriteHeader(code)
+}
+
+// Hijack hijacker for http
+func (w *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ hj, ok := w.ResponseWriter.(http.Hijacker)
+ if !ok {
+ return nil, nil, errors.New("webserver doesn't support hijacking")
+ }
+ return hj.Hijack()
+}
+
+// Flush http.Flusher
+func (w *Response) Flush() {
+ if f, ok := w.ResponseWriter.(http.Flusher); ok {
+ f.Flush()
+ }
+}
+
+// CloseNotify http.CloseNotifier
+func (w *Response) CloseNotify() <-chan bool {
+ if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
+ return cn.CloseNotify()
+ }
+ return nil
+}
diff --git a/context/input.go b/context/input.go
index f535e6a2..c37204bd 100644
--- a/context/input.go
+++ b/context/input.go
@@ -17,50 +17,69 @@ package context
import (
"bytes"
"errors"
+ "io"
"io/ioutil"
- "net/http"
"net/url"
"reflect"
+ "regexp"
"strconv"
"strings"
"github.com/astaxie/beego/session"
)
+// Regexes for checking the accept headers
+// TODO make sure these are correct
+var (
+ acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`)
+ acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`)
+ acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`)
+ maxParam = 50
+)
+
// BeegoInput operates the http request header, data, cookie and body.
// it also contains router params and current session.
type BeegoInput struct {
- CruSession session.SessionStore
- Params map[string]string
- Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
- Request *http.Request
- RequestBody []byte
- RunController reflect.Type
- RunMethod string
+ Context *Context
+ CruSession session.Store
+ pnames []string
+ pvalues []string
+ data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
+ RequestBody []byte
}
-// NewInput return BeegoInput generated by http.Request.
-func NewInput(req *http.Request) *BeegoInput {
+// NewInput return BeegoInput generated by Context.
+func NewInput() *BeegoInput {
return &BeegoInput{
- Params: make(map[string]string),
- Data: make(map[interface{}]interface{}),
- Request: req,
+ pnames: make([]string, 0, maxParam),
+ pvalues: make([]string, 0, maxParam),
+ data: make(map[interface{}]interface{}),
}
}
+// Reset init the BeegoInput
+func (input *BeegoInput) Reset(ctx *Context) {
+ input.Context = ctx
+ input.CruSession = nil
+ input.pnames = input.pnames[:0]
+ input.pvalues = input.pvalues[:0]
+ input.data = nil
+ input.RequestBody = []byte{}
+}
+
// Protocol returns request protocol name, such as HTTP/1.1 .
func (input *BeegoInput) Protocol() string {
- return input.Request.Proto
+ return input.Context.Request.Proto
}
-// Uri returns full request url with query string, fragment.
-func (input *BeegoInput) Uri() string {
- return input.Request.RequestURI
+// URI returns full request url with query string, fragment.
+func (input *BeegoInput) URI() string {
+ return input.Context.Request.RequestURI
}
-// Url returns request url path (without query string, fragment).
-func (input *BeegoInput) Url() string {
- return input.Request.URL.Path
+// URL returns request url path (without query string, fragment).
+func (input *BeegoInput) URL() string {
+ return input.Context.Request.URL.Path
}
// Site returns base site url as scheme://domain type.
@@ -70,10 +89,10 @@ func (input *BeegoInput) Site() string {
// Scheme returns request scheme as "http" or "https".
func (input *BeegoInput) Scheme() string {
- if input.Request.URL.Scheme != "" {
- return input.Request.URL.Scheme
+ if input.Context.Request.URL.Scheme != "" {
+ return input.Context.Request.URL.Scheme
}
- if input.Request.TLS == nil {
+ if input.Context.Request.TLS == nil {
return "http"
}
return "https"
@@ -88,19 +107,19 @@ func (input *BeegoInput) Domain() string {
// Host returns host name.
// if no host info in request, return localhost.
func (input *BeegoInput) Host() string {
- if input.Request.Host != "" {
- hostParts := strings.Split(input.Request.Host, ":")
+ if input.Context.Request.Host != "" {
+ hostParts := strings.Split(input.Context.Request.Host, ":")
if len(hostParts) > 0 {
return hostParts[0]
}
- return input.Request.Host
+ return input.Context.Request.Host
}
return "localhost"
}
// Method returns http request method.
func (input *BeegoInput) Method() string {
- return input.Request.Method
+ return input.Context.Request.Method
}
// Is returns boolean of this request is on given method, such as Is("POST").
@@ -108,37 +127,37 @@ func (input *BeegoInput) Is(method string) bool {
return input.Method() == method
}
-// Is this a GET method request?
+// IsGet Is this a GET method request?
func (input *BeegoInput) IsGet() bool {
return input.Is("GET")
}
-// Is this a POST method request?
+// IsPost Is this a POST method request?
func (input *BeegoInput) IsPost() bool {
return input.Is("POST")
}
-// Is this a Head method request?
+// IsHead Is this a Head method request?
func (input *BeegoInput) IsHead() bool {
return input.Is("HEAD")
}
-// Is this a OPTIONS method request?
+// IsOptions Is this a OPTIONS method request?
func (input *BeegoInput) IsOptions() bool {
return input.Is("OPTIONS")
}
-// Is this a PUT method request?
+// IsPut Is this a PUT method request?
func (input *BeegoInput) IsPut() bool {
return input.Is("PUT")
}
-// Is this a DELETE method request?
+// IsDelete Is this a DELETE method request?
func (input *BeegoInput) IsDelete() bool {
return input.Is("DELETE")
}
-// Is this a PATCH method request?
+// IsPatch Is this a PATCH method request?
func (input *BeegoInput) IsPatch() bool {
return input.Is("PATCH")
}
@@ -163,6 +182,21 @@ func (input *BeegoInput) IsUpload() bool {
return strings.Contains(input.Header("Content-Type"), "multipart/form-data")
}
+// AcceptsHTML Checks if request accepts html response
+func (input *BeegoInput) AcceptsHTML() bool {
+ return acceptsHTMLRegex.MatchString(input.Header("Accept"))
+}
+
+// AcceptsXML Checks if request accepts xml response
+func (input *BeegoInput) AcceptsXML() bool {
+ return acceptsXMLRegex.MatchString(input.Header("Accept"))
+}
+
+// AcceptsJSON Checks if request accepts json response
+func (input *BeegoInput) AcceptsJSON() bool {
+ return acceptsJSONRegex.MatchString(input.Header("Accept"))
+}
+
// IP returns request client ip.
// if in proxy, return first proxy id.
// if error, return 127.0.0.1.
@@ -172,7 +206,7 @@ func (input *BeegoInput) IP() string {
rip := strings.Split(ips[0], ":")
return rip[0]
}
- ip := strings.Split(input.Request.RemoteAddr, ":")
+ ip := strings.Split(input.Context.Request.RemoteAddr, ":")
if len(ip) > 0 {
if ip[0] != "[" {
return ip[0]
@@ -212,7 +246,7 @@ func (input *BeegoInput) SubDomains() string {
// Port returns request client port.
// when error or empty, return 80.
func (input *BeegoInput) Port() int {
- parts := strings.Split(input.Request.Host, ":")
+ parts := strings.Split(input.Context.Request.Host, ":")
if len(parts) == 2 {
port, _ := strconv.Atoi(parts[1])
return port
@@ -225,35 +259,59 @@ func (input *BeegoInput) UserAgent() string {
return input.Header("User-Agent")
}
+// ParamsLen return the length of the params
+func (input *BeegoInput) ParamsLen() int {
+ return len(input.pnames)
+}
+
// Param returns router param by a given key.
func (input *BeegoInput) Param(key string) string {
- if v, ok := input.Params[key]; ok {
- return v
+ for i, v := range input.pnames {
+ if v == key && i <= len(input.pvalues) {
+ return input.pvalues[i]
+ }
}
return ""
}
+// Params returns the map[key]value.
+func (input *BeegoInput) Params() map[string]string {
+ m := make(map[string]string)
+ for i, v := range input.pnames {
+ if i <= len(input.pvalues) {
+ m[v] = input.pvalues[i]
+ }
+ }
+ return m
+}
+
+// SetParam will set the param with key and value
+func (input *BeegoInput) SetParam(key, val string) {
+ input.pvalues = append(input.pvalues, val)
+ input.pnames = append(input.pnames, key)
+}
+
// Query returns input data item string by a given string.
func (input *BeegoInput) Query(key string) string {
if val := input.Param(key); val != "" {
return val
}
- if input.Request.Form == nil {
- input.Request.ParseForm()
+ if input.Context.Request.Form == nil {
+ input.Context.Request.ParseForm()
}
- return input.Request.Form.Get(key)
+ return input.Context.Request.Form.Get(key)
}
// Header returns request header item string by a given string.
// if non-existed, return empty string.
func (input *BeegoInput) Header(key string) string {
- return input.Request.Header.Get(key)
+ return input.Context.Request.Header.Get(key)
}
// Cookie returns request cookie item string by a given key.
// if non-existed, return empty string.
func (input *BeegoInput) Cookie(key string) string {
- ck, err := input.Request.Cookie(key)
+ ck, err := input.Context.Request.Cookie(key)
if err != nil {
return ""
}
@@ -267,18 +325,27 @@ func (input *BeegoInput) Session(key interface{}) interface{} {
}
// CopyBody returns the raw request body data as bytes.
-func (input *BeegoInput) CopyBody() []byte {
- requestbody, _ := ioutil.ReadAll(input.Request.Body)
- input.Request.Body.Close()
+func (input *BeegoInput) CopyBody(MaxMemory int64) []byte {
+ safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory}
+ requestbody, _ := ioutil.ReadAll(safe)
+ input.Context.Request.Body.Close()
bf := bytes.NewBuffer(requestbody)
- input.Request.Body = ioutil.NopCloser(bf)
+ input.Context.Request.Body = ioutil.NopCloser(bf)
input.RequestBody = requestbody
return requestbody
}
+// Data return the implicit data in the input
+func (input *BeegoInput) Data() map[interface{}]interface{} {
+ if input.data == nil {
+ input.data = make(map[interface{}]interface{})
+ }
+ return input.data
+}
+
// GetData returns the stored data in this context.
func (input *BeegoInput) GetData(key interface{}) interface{} {
- if v, ok := input.Data[key]; ok {
+ if v, ok := input.data[key]; ok {
return v
}
return nil
@@ -287,17 +354,20 @@ func (input *BeegoInput) GetData(key interface{}) interface{} {
// SetData stores data with given key in this context.
// This data are only available in this context.
func (input *BeegoInput) SetData(key, val interface{}) {
- input.Data[key] = val
+ if input.data == nil {
+ input.data = make(map[interface{}]interface{})
+ }
+ input.data[key] = val
}
-// parseForm or parseMultiForm based on Content-type
+// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type
func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error {
// Parse the body depending on the content type.
if strings.Contains(input.Header("Content-Type"), "multipart/form-data") {
- if err := input.Request.ParseMultipartForm(maxMemory); err != nil {
+ if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil {
return errors.New("Error parsing request body:" + err.Error())
}
- } else if err := input.Request.ParseForm(); err != nil {
+ } else if err := input.Context.Request.ParseForm(); err != nil {
return errors.New("Error parsing request body:" + err.Error())
}
return nil
@@ -306,7 +376,7 @@ func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error {
// Bind data from request.Form[key] to dest
// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie
// var id int beegoInput.Bind(&id, "id") id ==123
-// var isok bool beegoInput.Bind(&isok, "isok") id ==true
+// var isok bool beegoInput.Bind(&isok, "isok") isok ==true
// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2
// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2]
// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array]
@@ -329,7 +399,7 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error {
}
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
- rv := reflect.Zero(reflect.TypeOf(0))
+ rv := reflect.Zero(typ)
switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
val := input.Query(key)
@@ -362,19 +432,19 @@ func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
}
rv = input.bindBool(val, typ)
case reflect.Slice:
- rv = input.bindSlice(&input.Request.Form, key, typ)
+ rv = input.bindSlice(&input.Context.Request.Form, key, typ)
case reflect.Struct:
- rv = input.bindStruct(&input.Request.Form, key, typ)
+ rv = input.bindStruct(&input.Context.Request.Form, key, typ)
case reflect.Ptr:
rv = input.bindPoint(key, typ)
case reflect.Map:
- rv = input.bindMap(&input.Request.Form, key, typ)
+ rv = input.bindMap(&input.Context.Request.Form, key, typ)
}
return rv
}
func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value {
- rv := reflect.Zero(reflect.TypeOf(0))
+ rv := reflect.Zero(typ)
switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
rv = input.bindInt(val, typ)
diff --git a/context/input_test.go b/context/input_test.go
index 4566f6d6..618e1254 100644
--- a/context/input_test.go
+++ b/context/input_test.go
@@ -17,12 +17,15 @@ package context
import (
"fmt"
"net/http"
+ "net/http/httptest"
"testing"
)
func TestParse(t *testing.T) {
r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
- beegoInput := NewInput(r)
+ beegoInput := NewInput()
+ beegoInput.Context = NewContext()
+ beegoInput.Context.Reset(httptest.NewRecorder(), r)
beegoInput.ParseFormOrMulitForm(1 << 20)
var id int
@@ -73,7 +76,9 @@ func TestParse(t *testing.T) {
func TestSubDomain(t *testing.T) {
r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
- beegoInput := NewInput(r)
+ beegoInput := NewInput()
+ beegoInput.Context = NewContext()
+ beegoInput.Context.Reset(httptest.NewRecorder(), r)
subdomain := beegoInput.SubDomains()
if subdomain != "www" {
@@ -81,13 +86,13 @@ func TestSubDomain(t *testing.T) {
}
r, _ = http.NewRequest("GET", "http://localhost/", nil)
- beegoInput.Request = r
+ beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains())
}
r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil)
- beegoInput.Request = r
+ beegoInput.Context.Request = r
if beegoInput.SubDomains() != "aa.bb" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
@@ -101,13 +106,13 @@ func TestSubDomain(t *testing.T) {
*/
r, _ = http.NewRequest("GET", "http://example.com/", nil)
- beegoInput.Request = r
+ beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil)
- beegoInput.Request = r
+ beegoInput.Context.Request = r
if beegoInput.SubDomains() != "aa.bb.cc.dd" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
diff --git a/context/output.go b/context/output.go
index 2141513d..2d756e27 100644
--- a/context/output.go
+++ b/context/output.go
@@ -16,8 +16,6 @@ package context
import (
"bytes"
- "compress/flate"
- "compress/gzip"
"encoding/json"
"encoding/xml"
"errors"
@@ -29,6 +27,7 @@ import (
"path/filepath"
"strconv"
"strings"
+ "time"
)
// BeegoOutput does work for sending response header.
@@ -44,6 +43,12 @@ func NewOutput() *BeegoOutput {
return &BeegoOutput{}
}
+// Reset init BeegoOutput
+func (output *BeegoOutput) Reset(ctx *Context) {
+ output.Context = ctx
+ output.Status = 0
+}
+
// Header sets response header item string via given key.
func (output *BeegoOutput) Header(key, val string) {
output.Context.ResponseWriter.Header().Set(key, val)
@@ -53,30 +58,16 @@ func (output *BeegoOutput) Header(key, val string) {
// if EnableGzip, compress content string.
// it sends out response body directly.
func (output *BeegoOutput) Body(content []byte) {
- output_writer := output.Context.ResponseWriter.(io.Writer)
- if output.EnableGzip == true && output.Context.Input.Header("Accept-Encoding") != "" {
- splitted := strings.SplitN(output.Context.Input.Header("Accept-Encoding"), ",", -1)
- encodings := make([]string, len(splitted))
-
- for i, val := range splitted {
- encodings[i] = strings.TrimSpace(val)
- }
- for _, val := range encodings {
- if val == "gzip" {
- output.Header("Content-Encoding", "gzip")
- output_writer, _ = gzip.NewWriterLevel(output.Context.ResponseWriter, gzip.BestSpeed)
-
- break
- } else if val == "deflate" {
- output.Header("Content-Encoding", "deflate")
- output_writer, _ = flate.NewWriter(output.Context.ResponseWriter, flate.BestSpeed)
- break
- }
- }
+ var encoding string
+ var buf = &bytes.Buffer{}
+ if output.EnableGzip {
+ encoding = ParseEncoding(output.Context.Request)
+ }
+ if b, n, _ := WriteBody(encoding, buf, content); b {
+ output.Header("Content-Encoding", n)
} else {
output.Header("Content-Length", strconv.Itoa(len(content)))
}
-
// Write status code if it has been set manually
// Set it to 0 afterwards to prevent "multiple response.WriteHeader calls"
if output.Status != 0 {
@@ -84,13 +75,7 @@ func (output *BeegoOutput) Body(content []byte) {
output.Status = 0
}
- output_writer.Write(content)
- switch output_writer.(type) {
- case *gzip.Writer:
- output_writer.(*gzip.Writer).Close()
- case *flate.Writer:
- output_writer.(*flate.Writer).Close()
- }
+ io.Copy(output.Context.ResponseWriter, buf)
}
// Cookie sets cookie value via given key.
@@ -98,26 +83,24 @@ func (output *BeegoOutput) Body(content []byte) {
func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) {
var b bytes.Buffer
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
+
+ //fix cookie not work in IE
if len(others) > 0 {
+ var maxAge int64
+
switch v := others[0].(type) {
case int:
- if v > 0 {
- fmt.Fprintf(&b, "; Max-Age=%d", v)
- } else if v < 0 {
- fmt.Fprintf(&b, "; Max-Age=0")
- }
- case int64:
- if v > 0 {
- fmt.Fprintf(&b, "; Max-Age=%d", v)
- } else if v < 0 {
- fmt.Fprintf(&b, "; Max-Age=0")
- }
+ maxAge = int64(v)
case int32:
- if v > 0 {
- fmt.Fprintf(&b, "; Max-Age=%d", v)
- } else if v < 0 {
- fmt.Fprintf(&b, "; Max-Age=0")
- }
+ maxAge = int64(v)
+ case int64:
+ maxAge = v
+ }
+
+ if maxAge > 0 {
+ fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge)
+ } else {
+ fmt.Fprintf(&b, "; Max-Age=0")
}
}
@@ -185,9 +168,9 @@ func sanitizeValue(v string) string {
return cookieValueSanitizer.Replace(v)
}
-// Json writes json to response body.
+// JSON writes json to response body.
// if coding is true, it converts utf-8 to \u0000 type.
-func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) error {
+func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error {
output.Header("Content-Type", "application/json; charset=utf-8")
var content []byte
var err error
@@ -201,14 +184,14 @@ func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) e
return err
}
if coding {
- content = []byte(stringsToJson(string(content)))
+ content = []byte(stringsToJSON(string(content)))
}
output.Body(content)
return nil
}
-// Jsonp writes jsonp to response body.
-func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
+// JSONP writes jsonp to response body.
+func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/javascript; charset=utf-8")
var content []byte
var err error
@@ -225,16 +208,16 @@ func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
if callback == "" {
return errors.New(`"callback" parameter required`)
}
- callback_content := bytes.NewBufferString(" " + template.JSEscapeString(callback))
- callback_content.WriteString("(")
- callback_content.Write(content)
- callback_content.WriteString(");\r\n")
- output.Body(callback_content.Bytes())
+ callbackContent := bytes.NewBufferString(" " + template.JSEscapeString(callback))
+ callbackContent.WriteString("(")
+ callbackContent.Write(content)
+ callbackContent.WriteString(");\r\n")
+ output.Body(callbackContent.Bytes())
return nil
}
-// Xml writes xml string to response body.
-func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error {
+// XML writes xml string to response body.
+func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/xml; charset=utf-8")
var content []byte
var err error
@@ -328,7 +311,7 @@ func (output *BeegoOutput) IsNotFound(status int) bool {
return output.Status == 404
}
-// IsClient returns boolean of this request client sends error data.
+// IsClientError returns boolean of this request client sends error data.
// HTTP 4xx means forbidden.
func (output *BeegoOutput) IsClientError(status int) bool {
return output.Status >= 400 && output.Status < 500
@@ -340,7 +323,7 @@ func (output *BeegoOutput) IsServerError(status int) bool {
return output.Status >= 500 && output.Status < 600
}
-func stringsToJson(str string) string {
+func stringsToJSON(str string) string {
rs := []rune(str)
jsons := ""
for _, r := range rs {
@@ -354,7 +337,7 @@ func stringsToJson(str string) string {
return jsons
}
-// Sessions sets session item value with given key.
+// Session sets session item value with given key.
func (output *BeegoOutput) Session(name interface{}, value interface{}) {
output.Context.Input.CruSession.Set(name, value)
}
diff --git a/controller.go b/controller.go
index e056f52d..a2943d42 100644
--- a/controller.go
+++ b/controller.go
@@ -19,7 +19,6 @@ import (
"errors"
"html/template"
"io"
- "io/ioutil"
"mime/multipart"
"net/http"
"net/url"
@@ -34,18 +33,19 @@ import (
//commonly used mime-types
const (
- applicationJson = "application/json"
- applicationXml = "application/xml"
- textXml = "text/xml"
+ applicationJSON = "application/json"
+ applicationXML = "application/xml"
+ textXML = "text/xml"
)
var (
- // custom error when user stop request handler manually.
- USERSTOPRUN = errors.New("User stop run")
- GlobalControllerRouter map[string][]ControllerComments = make(map[string][]ControllerComments) //pkgpath+controller:comments
+ // ErrAbort custom error when user stop request handler manually.
+ ErrAbort = errors.New("User stop run")
+ // GlobalControllerRouter store comments with controller. pkgpath+controller:comments
+ GlobalControllerRouter = make(map[string][]ControllerComments)
)
-// store the comment for the controller method
+// ControllerComments store the comment for the controller method
type ControllerComments struct {
Method string
Router string
@@ -56,22 +56,31 @@ type ControllerComments struct {
// Controller defines some basic http request handler operations, such as
// http context, template and view, session and xsrf.
type Controller struct {
- Ctx *context.Context
- Data map[interface{}]interface{}
+ // context data
+ Ctx *context.Context
+ Data map[interface{}]interface{}
+
+ // route controller info
controllerName string
actionName string
- TplNames string
+ methodMapping map[string]func() //method:routertree
+ gotofunc string
+ AppController interface{}
+
+ // template data
+ TplName string
Layout string
LayoutSections map[string]string // the key is the section name and the value is the template name
TplExt string
- _xsrf_token string
- gotofunc string
- CruSession session.SessionStore
- XSRFExpire int
- AppController interface{}
EnableRender bool
- EnableXSRF bool
- methodMapping map[string]func() //method:routertree
+
+ // xsrf data
+ _xsrfToken string
+ XSRFExpire int
+ EnableXSRF bool
+
+ // session
+ CruSession session.Store
}
// ControllerInterface is an interface to uniform all controller handler.
@@ -87,8 +96,8 @@ type ControllerInterface interface {
Options()
Finish()
Render() error
- XsrfToken() string
- CheckXsrfCookie() bool
+ XSRFToken() string
+ CheckXSRFCookie() bool
HandlerFunc(fn string) bool
URLMapping()
}
@@ -96,7 +105,7 @@ type ControllerInterface interface {
// Init generates default values of controller operations.
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
c.Layout = ""
- c.TplNames = ""
+ c.TplName = ""
c.controllerName = controllerName
c.actionName = actionName
c.Ctx = ctx
@@ -104,19 +113,15 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin
c.AppController = app
c.EnableRender = true
c.EnableXSRF = true
- c.Data = ctx.Input.Data
+ c.Data = ctx.Input.Data()
c.methodMapping = make(map[string]func())
}
// Prepare runs after Init before request function execution.
-func (c *Controller) Prepare() {
-
-}
+func (c *Controller) Prepare() {}
// Finish runs after request function execution.
-func (c *Controller) Finish() {
-
-}
+func (c *Controller) Finish() {}
// Get adds a request function to handle GET request.
func (c *Controller) Get() {
@@ -153,20 +158,19 @@ func (c *Controller) Options() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
}
-// call function fn
+// HandlerFunc call function with the name
func (c *Controller) HandlerFunc(fnname string) bool {
if v, ok := c.methodMapping[fnname]; ok {
v()
return true
- } else {
- return false
}
+ return false
}
// URLMapping register the internal Controller router.
-func (c *Controller) URLMapping() {
-}
+func (c *Controller) URLMapping() {}
+// Mapping the method to function
func (c *Controller) Mapping(method string, fn func()) {
c.methodMapping[method] = fn
}
@@ -177,13 +181,11 @@ func (c *Controller) Render() error {
return nil
}
rb, err := c.RenderBytes()
-
if err != nil {
return err
- } else {
- c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
- c.Ctx.Output.Body(rb)
}
+ c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
+ c.Ctx.Output.Body(rb)
return nil
}
@@ -196,24 +198,33 @@ func (c *Controller) RenderString() (string, error) {
// RenderBytes returns the bytes of rendered template string. Do not send out response.
func (c *Controller) RenderBytes() ([]byte, error) {
//if the controller has set layout, then first get the tplname's content set the content to the layout
+ var buf bytes.Buffer
if c.Layout != "" {
- if c.TplNames == "" {
- c.TplNames = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
+ if c.TplName == "" {
+ c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
}
- if RunMode == "dev" {
- BuildTemplate(ViewsPath)
+
+ if BConfig.RunMode == DEV {
+ buildFiles := []string{c.TplName}
+ if c.LayoutSections != nil {
+ for _, sectionTpl := range c.LayoutSections {
+ if sectionTpl == "" {
+ continue
+ }
+ buildFiles = append(buildFiles, sectionTpl)
+ }
+ }
+ BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...)
}
- newbytes := bytes.NewBufferString("")
- if _, ok := BeeTemplates[c.TplNames]; !ok {
- panic("can't find templatefile in the path:" + c.TplNames)
+ if _, ok := BeeTemplates[c.TplName]; !ok {
+ panic("can't find templatefile in the path:" + c.TplName)
}
- err := BeeTemplates[c.TplNames].ExecuteTemplate(newbytes, c.TplNames, c.Data)
+ err := BeeTemplates[c.TplName].ExecuteTemplate(&buf, c.TplName, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
- tplcontent, _ := ioutil.ReadAll(newbytes)
- c.Data["LayoutContent"] = template.HTML(string(tplcontent))
+ c.Data["LayoutContent"] = template.HTML(buf.String())
if c.LayoutSections != nil {
for sectionName, sectionTpl := range c.LayoutSections {
@@ -222,44 +233,41 @@ func (c *Controller) RenderBytes() ([]byte, error) {
continue
}
- sectionBytes := bytes.NewBufferString("")
- err = BeeTemplates[sectionTpl].ExecuteTemplate(sectionBytes, sectionTpl, c.Data)
+ buf.Reset()
+ err = BeeTemplates[sectionTpl].ExecuteTemplate(&buf, sectionTpl, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
- sectionContent, _ := ioutil.ReadAll(sectionBytes)
- c.Data[sectionName] = template.HTML(string(sectionContent))
+ c.Data[sectionName] = template.HTML(buf.String())
}
}
- ibytes := bytes.NewBufferString("")
- err = BeeTemplates[c.Layout].ExecuteTemplate(ibytes, c.Layout, c.Data)
+ buf.Reset()
+ err = BeeTemplates[c.Layout].ExecuteTemplate(&buf, c.Layout, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
- icontent, _ := ioutil.ReadAll(ibytes)
- return icontent, nil
- } else {
- if c.TplNames == "" {
- c.TplNames = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
- }
- if RunMode == "dev" {
- BuildTemplate(ViewsPath)
- }
- ibytes := bytes.NewBufferString("")
- if _, ok := BeeTemplates[c.TplNames]; !ok {
- panic("can't find templatefile in the path:" + c.TplNames)
- }
- err := BeeTemplates[c.TplNames].ExecuteTemplate(ibytes, c.TplNames, c.Data)
- if err != nil {
- Trace("template Execute err:", err)
- return nil, err
- }
- icontent, _ := ioutil.ReadAll(ibytes)
- return icontent, nil
+ return buf.Bytes(), nil
}
+
+ if c.TplName == "" {
+ c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
+ }
+ if BConfig.RunMode == DEV {
+ BuildTemplate(BConfig.WebConfig.ViewsPath, c.TplName)
+ }
+ if _, ok := BeeTemplates[c.TplName]; !ok {
+ panic("can't find templatefile in the path:" + c.TplName)
+ }
+ buf.Reset()
+ err := BeeTemplates[c.TplName].ExecuteTemplate(&buf, c.TplName, c.Data)
+ if err != nil {
+ Trace("template Execute err:", err)
+ return nil, err
+ }
+ return buf.Bytes(), nil
}
// Redirect sends the redirection response to url with status code.
@@ -267,7 +275,7 @@ func (c *Controller) Redirect(url string, code int) {
c.Ctx.Redirect(code, url)
}
-// Aborts stops controller handler and show the error data if code is defined in ErrorMap or code string.
+// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string.
func (c *Controller) Abort(code string) {
status, err := strconv.Atoi(code)
if err != nil {
@@ -285,74 +293,69 @@ func (c *Controller) CustomAbort(status int, body string) {
}
// last panic user string
c.Ctx.ResponseWriter.Write([]byte(body))
- panic(USERSTOPRUN)
+ panic(ErrAbort)
}
// StopRun makes panic of USERSTOPRUN error and go to recover function if defined.
func (c *Controller) StopRun() {
- panic(USERSTOPRUN)
+ panic(ErrAbort)
}
-// UrlFor does another controller handler in this request function.
+// URLFor does another controller handler in this request function.
// it goes to this controller method if endpoint is not clear.
-func (c *Controller) UrlFor(endpoint string, values ...interface{}) string {
- if len(endpoint) <= 0 {
+func (c *Controller) URLFor(endpoint string, values ...interface{}) string {
+ if len(endpoint) == 0 {
return ""
}
if endpoint[0] == '.' {
- return UrlFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...)
- } else {
- return UrlFor(endpoint, values...)
+ return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...)
}
+ return URLFor(endpoint, values...)
}
-// ServeJson sends a json response with encoding charset.
-func (c *Controller) ServeJson(encoding ...bool) {
- var hasIndent bool
- var hasencoding bool
- if RunMode == "prod" {
+// ServeJSON sends a json response with encoding charset.
+func (c *Controller) ServeJSON(encoding ...bool) {
+ var (
+ hasIndent = true
+ hasEncoding = false
+ )
+ if BConfig.RunMode == PROD {
hasIndent = false
- } else {
- hasIndent = true
}
if len(encoding) > 0 && encoding[0] == true {
- hasencoding = true
+ hasEncoding = true
}
- c.Ctx.Output.Json(c.Data["json"], hasIndent, hasencoding)
+ c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding)
}
-// ServeJsonp sends a jsonp response.
-func (c *Controller) ServeJsonp() {
- var hasIndent bool
- if RunMode == "prod" {
+// ServeJSONP sends a jsonp response.
+func (c *Controller) ServeJSONP() {
+ hasIndent := true
+ if BConfig.RunMode == PROD {
hasIndent = false
- } else {
- hasIndent = true
}
- c.Ctx.Output.Jsonp(c.Data["jsonp"], hasIndent)
+ c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent)
}
-// ServeXml sends xml response.
-func (c *Controller) ServeXml() {
- var hasIndent bool
- if RunMode == "prod" {
+// ServeXML sends xml response.
+func (c *Controller) ServeXML() {
+ hasIndent := true
+ if BConfig.RunMode == PROD {
hasIndent = false
- } else {
- hasIndent = true
}
- c.Ctx.Output.Xml(c.Data["xml"], hasIndent)
+ c.Ctx.Output.XML(c.Data["xml"], hasIndent)
}
// ServeFormatted serve Xml OR Json, depending on the value of the Accept header
func (c *Controller) ServeFormatted() {
accept := c.Ctx.Input.Header("Accept")
switch accept {
- case applicationJson:
- c.ServeJson()
- case applicationXml, textXml:
- c.ServeXml()
+ case applicationJSON:
+ c.ServeJSON()
+ case applicationXML, textXML:
+ c.ServeXML()
default:
- c.ServeJson()
+ c.ServeJSON()
}
}
@@ -371,16 +374,13 @@ func (c *Controller) ParseForm(obj interface{}) error {
// GetString returns the input value by key string or the default value while it's present and input is blank
func (c *Controller) GetString(key string, def ...string) string {
- var defv string
- if len(def) > 0 {
- defv = def[0]
- }
-
if v := c.Ctx.Input.Query(key); v != "" {
return v
- } else {
- return defv
}
+ if len(def) > 0 {
+ return def[0]
+ }
+ return ""
}
// GetStrings returns the input string slice by key string or the default value while it's present and input is blank
@@ -391,122 +391,79 @@ func (c *Controller) GetStrings(key string, def ...[]string) []string {
defv = def[0]
}
- f := c.Input()
- if f == nil {
+ if f := c.Input(); f == nil {
return defv
+ } else if vs := f[key]; len(vs) > 0 {
+ return vs
}
- vs := f[key]
- if len(vs) > 0 {
- return vs
- } else {
- return defv
- }
+ return defv
}
// GetInt returns input as an int or the default value while it's present and input is blank
func (c *Controller) GetInt(key string, def ...int) (int, error) {
- var defv int
- if len(def) > 0 {
- defv = def[0]
- }
-
- if strv := c.Ctx.Input.Query(key); strv != "" {
- return strconv.Atoi(strv)
- } else {
- return defv, nil
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
+ return def[0], nil
}
+ return strconv.Atoi(strv)
}
// GetInt8 return input as an int8 or the default value while it's present and input is blank
func (c *Controller) GetInt8(key string, def ...int8) (int8, error) {
- var defv int8
- if len(def) > 0 {
- defv = def[0]
- }
-
- if strv := c.Ctx.Input.Query(key); strv != "" {
- i64, err := strconv.ParseInt(strv, 10, 8)
- i8 := int8(i64)
- return i8, err
- } else {
- return defv, nil
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
+ return def[0], nil
}
+ i64, err := strconv.ParseInt(strv, 10, 8)
+ return int8(i64), err
}
// GetInt16 returns input as an int16 or the default value while it's present and input is blank
func (c *Controller) GetInt16(key string, def ...int16) (int16, error) {
- var defv int16
- if len(def) > 0 {
- defv = def[0]
- }
-
- if strv := c.Ctx.Input.Query(key); strv != "" {
- i64, err := strconv.ParseInt(strv, 10, 16)
- i16 := int16(i64)
-
- return i16, err
- } else {
- return defv, nil
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
+ return def[0], nil
}
+ i64, err := strconv.ParseInt(strv, 10, 16)
+ return int16(i64), err
}
// GetInt32 returns input as an int32 or the default value while it's present and input is blank
func (c *Controller) GetInt32(key string, def ...int32) (int32, error) {
- var defv int32
- if len(def) > 0 {
- defv = def[0]
- }
-
- if strv := c.Ctx.Input.Query(key); strv != "" {
- i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 32)
- i32 := int32(i64)
- return i32, err
- } else {
- return defv, nil
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
+ return def[0], nil
}
+ i64, err := strconv.ParseInt(strv, 10, 32)
+ return int32(i64), err
}
// GetInt64 returns input value as int64 or the default value while it's present and input is blank.
func (c *Controller) GetInt64(key string, def ...int64) (int64, error) {
- var defv int64
- if len(def) > 0 {
- defv = def[0]
- }
-
- if strv := c.Ctx.Input.Query(key); strv != "" {
- return strconv.ParseInt(strv, 10, 64)
- } else {
- return defv, nil
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
+ return def[0], nil
}
+ return strconv.ParseInt(strv, 10, 64)
}
// GetBool returns input value as bool or the default value while it's present and input is blank.
func (c *Controller) GetBool(key string, def ...bool) (bool, error) {
- var defv bool
- if len(def) > 0 {
- defv = def[0]
- }
-
- if strv := c.Ctx.Input.Query(key); strv != "" {
- return strconv.ParseBool(strv)
- } else {
- return defv, nil
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
+ return def[0], nil
}
+ return strconv.ParseBool(strv)
}
// GetFloat returns input value as float64 or the default value while it's present and input is blank.
func (c *Controller) GetFloat(key string, def ...float64) (float64, error) {
- var defv float64
- if len(def) > 0 {
- defv = def[0]
- }
-
- if strv := c.Ctx.Input.Query(key); strv != "" {
- return strconv.ParseFloat(c.Ctx.Input.Query(key), 64)
- } else {
- return defv, nil
+ strv := c.Ctx.Input.Query(key)
+ if len(strv) == 0 && len(def) > 0 {
+ return def[0], nil
}
+ return strconv.ParseFloat(strv, 64)
}
// GetFile returns the file data in file upload field named as key.
@@ -515,6 +472,40 @@ func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader,
return c.Ctx.Request.FormFile(key)
}
+// GetFiles return multi-upload files
+// files, err:=c.Getfiles("myfiles")
+// if err != nil {
+// http.Error(w, err.Error(), http.StatusNoContent)
+// return
+// }
+// for i, _ := range files {
+// //for each fileheader, get a handle to the actual file
+// file, err := files[i].Open()
+// defer file.Close()
+// if err != nil {
+// http.Error(w, err.Error(), http.StatusInternalServerError)
+// return
+// }
+// //create destination file making sure the path is writeable.
+// dst, err := os.Create("upload/" + files[i].Filename)
+// defer dst.Close()
+// if err != nil {
+// http.Error(w, err.Error(), http.StatusInternalServerError)
+// return
+// }
+// //copy the uploaded file to the destination file
+// if _, err := io.Copy(dst, file); err != nil {
+// http.Error(w, err.Error(), http.StatusInternalServerError)
+// return
+// }
+// }
+func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) {
+ if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok {
+ return files, nil
+ }
+ return nil, http.ErrMissingFile
+}
+
// SaveToFile saves uploaded file to new path.
// it only operates the first one of mutil-upload form file field.
func (c *Controller) SaveToFile(fromfile, tofile string) error {
@@ -533,7 +524,7 @@ func (c *Controller) SaveToFile(fromfile, tofile string) error {
}
// StartSession starts session and load old session data info this controller.
-func (c *Controller) StartSession() session.SessionStore {
+func (c *Controller) StartSession() session.Store {
if c.CruSession == nil {
c.CruSession = c.Ctx.Input.CruSession
}
@@ -556,7 +547,7 @@ func (c *Controller) GetSession(name interface{}) interface{} {
return c.CruSession.Get(name)
}
-// SetSession removes value from session.
+// DelSession removes value from session.
func (c *Controller) DelSession(name interface{}) {
if c.CruSession == nil {
c.StartSession()
@@ -570,13 +561,14 @@ func (c *Controller) SessionRegenerateID() {
if c.CruSession != nil {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
}
- c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
+ c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession
}
// DestroySession cleans session data and session cookie.
func (c *Controller) DestroySession() {
c.Ctx.Input.CruSession.Flush()
+ c.Ctx.Input.CruSession = nil
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
}
@@ -595,37 +587,35 @@ func (c *Controller) SetSecureCookie(Secret, name, value string, others ...inter
c.Ctx.SetSecureCookie(Secret, name, value, others...)
}
-// XsrfToken creates a xsrf token string and returns.
-func (c *Controller) XsrfToken() string {
- if c._xsrf_token == "" {
- var expire int64
+// XSRFToken creates a CSRF token string and returns.
+func (c *Controller) XSRFToken() string {
+ if c._xsrfToken == "" {
+ expire := int64(BConfig.WebConfig.XSRFExpire)
if c.XSRFExpire > 0 {
expire = int64(c.XSRFExpire)
- } else {
- expire = int64(XSRFExpire)
}
- c._xsrf_token = c.Ctx.XsrfToken(XSRFKEY, expire)
+ c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire)
}
- return c._xsrf_token
+ return c._xsrfToken
}
-// CheckXsrfCookie checks xsrf token in this request is valid or not.
+// CheckXSRFCookie checks xsrf token in this request is valid or not.
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
// or in form field value named as "_xsrf".
-func (c *Controller) CheckXsrfCookie() bool {
+func (c *Controller) CheckXSRFCookie() bool {
if !c.EnableXSRF {
return true
}
- return c.Ctx.CheckXsrfCookie()
+ return c.Ctx.CheckXSRFCookie()
}
-// XsrfFormHtml writes an input field contains xsrf token value.
-func (c *Controller) XsrfFormHtml() string {
- return ""
+// XSRFFormHTML writes an input field contains xsrf token value.
+func (c *Controller) XSRFFormHTML() string {
+ return ``
}
// GetControllerAndAction gets the executing controller name and action name.
-func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
+func (c *Controller) GetControllerAndAction() (string, string) {
return c.controllerName, c.actionName
}
diff --git a/controller_test.go b/controller_test.go
index 15938cdc..51d3a5b7 100644
--- a/controller_test.go
+++ b/controller_test.go
@@ -15,61 +15,63 @@
package beego
import (
- "fmt"
+ "testing"
+
"github.com/astaxie/beego/context"
)
-func ExampleGetInt() {
-
- i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
+func TestGetInt(t *testing.T) {
+ i := context.NewInput()
+ i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
-
val, _ := ctrlr.GetInt("age")
- fmt.Printf("%T", val)
- //Output: int
+ if val != 40 {
+ t.Errorf("TestGetInt expect 40,get %T,%v", val, val)
+ }
}
-func ExampleGetInt8() {
-
- i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
+func TestGetInt8(t *testing.T) {
+ i := context.NewInput()
+ i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
-
val, _ := ctrlr.GetInt8("age")
- fmt.Printf("%T", val)
+ if val != 40 {
+ t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val)
+ }
//Output: int8
}
-func ExampleGetInt16() {
-
- i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
+func TestGetInt16(t *testing.T) {
+ i := context.NewInput()
+ i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
-
val, _ := ctrlr.GetInt16("age")
- fmt.Printf("%T", val)
- //Output: int16
+ if val != 40 {
+ t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val)
+ }
}
-func ExampleGetInt32() {
-
- i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
+func TestGetInt32(t *testing.T) {
+ i := context.NewInput()
+ i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
-
val, _ := ctrlr.GetInt32("age")
- fmt.Printf("%T", val)
- //Output: int32
+ if val != 40 {
+ t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val)
+ }
}
-func ExampleGetInt64() {
-
- i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
+func TestGetInt64(t *testing.T) {
+ i := context.NewInput()
+ i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
-
val, _ := ctrlr.GetInt64("age")
- fmt.Printf("%T", val)
- //Output: int64
+ if val != 40 {
+ t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val)
+ }
}
diff --git a/doc.go b/doc.go
new file mode 100644
index 00000000..8825bd29
--- /dev/null
+++ b/doc.go
@@ -0,0 +1,17 @@
+/*
+Package beego provide a MVC framework
+beego: an open-source, high-performance, modular, full-stack web framework
+
+It is used for rapid development of RESTful APIs, web apps and backend services in Go.
+beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding.
+
+ package main
+ import "github.com/astaxie/beego"
+
+ func main() {
+ beego.Run()
+ }
+
+more information: http://beego.me
+*/
+package beego
diff --git a/docs.go b/docs.go
index aaad205e..72532876 100644
--- a/docs.go
+++ b/docs.go
@@ -15,37 +15,24 @@
package beego
import (
- "encoding/json"
-
"github.com/astaxie/beego/context"
)
-var GlobalDocApi map[string]interface{}
-
-func init() {
- if EnableDocs {
- GlobalDocApi = make(map[string]interface{})
- }
-}
+// GlobalDocAPI store the swagger api documents
+var GlobalDocAPI = make(map[string]interface{})
func serverDocs(ctx *context.Context) {
var obj interface{}
if splat := ctx.Input.Param(":splat"); splat == "" {
- obj = GlobalDocApi["Root"]
+ obj = GlobalDocAPI["Root"]
} else {
- if v, ok := GlobalDocApi[splat]; ok {
+ if v, ok := GlobalDocAPI[splat]; ok {
obj = v
}
}
if obj != nil {
- bt, err := json.Marshal(obj)
- if err != nil {
- ctx.Output.SetStatus(504)
- return
- }
- ctx.Output.Header("Content-Type", "application/json;charset=UTF-8")
ctx.Output.Header("Access-Control-Allow-Origin", "*")
- ctx.Output.Body(bt)
+ ctx.Output.JSON(obj, false, false)
return
}
ctx.Output.SetStatus(404)
diff --git a/error.go b/error.go
index 71be6916..94151dd8 100644
--- a/error.go
+++ b/error.go
@@ -82,17 +82,18 @@ var tpl = `
`
// render default application error page with error and stack string.
-func showErr(err interface{}, ctx *context.Context, Stack string) {
+func showErr(err interface{}, ctx *context.Context, stack string) {
t, _ := template.New("beegoerrortemp").Parse(tpl)
- data := make(map[string]string)
- data["AppError"] = AppName + ":" + fmt.Sprint(err)
- data["RequestMethod"] = ctx.Input.Method()
- data["RequestURL"] = ctx.Input.Uri()
- data["RemoteAddr"] = ctx.Input.IP()
- data["Stack"] = Stack
- data["BeegoVersion"] = VERSION
- data["GoVersion"] = runtime.Version()
- ctx.Output.SetStatus(500)
+ data := map[string]string{
+ "AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err),
+ "RequestMethod": ctx.Input.Method(),
+ "RequestURL": ctx.Input.URI(),
+ "RemoteAddr": ctx.Input.IP(),
+ "Stack": stack,
+ "BeegoVersion": VERSION,
+ "GoVersion": runtime.Version(),
+ }
+ ctx.ResponseWriter.WriteHeader(500)
t.Execute(ctx.ResponseWriter, data)
}
@@ -203,48 +204,49 @@ type errorInfo struct {
errorType int
}
-// map of http handlers for each error string.
-var ErrorMaps map[string]*errorInfo
-
-func init() {
- ErrorMaps = make(map[string]*errorInfo)
-}
+// ErrorMaps holds map of http handlers for each error string.
+// there is 10 kinds default error(40x and 50x)
+var ErrorMaps = make(map[string]*errorInfo, 10)
// show 401 unauthorized error.
func unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Unauthorized"
+ data := map[string]interface{}{
+ "Title": http.StatusText(401),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested can't be authorized." +
" Perhaps you are here because:" +
"
" +
" The credentials you supplied are incorrect" +
" There are errors in the website address" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 402 Payment Required
func paymentRequired(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Payment Required"
+ data := map[string]interface{}{
+ "Title": http.StatusText(402),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested Payment Required." +
" Perhaps you are here because:" +
"
" +
" The credentials you supplied are incorrect" +
" There are errors in the website address" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 403 forbidden error.
func forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Forbidden"
+ data := map[string]interface{}{
+ "Title": http.StatusText(403),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is forbidden." +
" Perhaps you are here because:" +
"
" +
@@ -252,15 +254,16 @@ func forbidden(rw http.ResponseWriter, r *http.Request) {
" The site may be disabled" +
" You need to log in" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 404 notfound error.
func notFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Page Not Found"
+ data := map[string]interface{}{
+ "Title": http.StatusText(404),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested has flown the coop." +
" Perhaps you are here because:" +
"
" +
@@ -269,194 +272,163 @@ func notFound(rw http.ResponseWriter, r *http.Request) {
" You were looking for your puppy and got lost" +
" You like 404 pages" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 405 Method Not Allowed
func methodNotAllowed(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Method Not Allowed"
+ data := map[string]interface{}{
+ "Title": http.StatusText(405),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The method you have requested Not Allowed." +
" Perhaps you are here because:" +
"
" +
" The method specified in the Request-Line is not allowed for the resource identified by the Request-URI" +
" The response MUST include an Allow header containing a list of valid methods for the requested resource." +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 500 internal server error.
func internalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Internal Server Error"
+ data := map[string]interface{}{
+ "Title": http.StatusText(500),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is down right now." +
"
" +
" Please try again later and report the error to the website administrator" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 501 Not Implemented.
func notImplemented(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Not Implemented"
+ data := map[string]interface{}{
+ "Title": http.StatusText(504),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is Not Implemented." +
"
" +
" Please try again later and report the error to the website administrator" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 502 Bad Gateway.
func badGateway(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Bad Gateway"
+ data := map[string]interface{}{
+ "Title": http.StatusText(502),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is down right now." +
"
" +
" The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request." +
" Please try again later and report the error to the website administrator" +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 503 service unavailable error.
func serviceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Service Unavailable"
+ data := map[string]interface{}{
+ "Title": http.StatusText(503),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is unavailable." +
" Perhaps you are here because:" +
"
" +
"
The page is overloaded" +
" Please try again later." +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 504 Gateway Timeout.
func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
- data := make(map[string]interface{})
- data["Title"] = "Gateway Timeout"
+ data := map[string]interface{}{
+ "Title": http.StatusText(504),
+ "BeegoVersion": VERSION,
+ }
data["Content"] = template.HTML(" The page you have requested is unavailable." +
" Perhaps you are here because:" +
"
" +
"
The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI." +
" Please try again later." +
"
")
- data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
-// register default error http handlers, 404,401,403,500 and 503.
-func registerDefaultErrorHandler() {
- if _, ok := ErrorMaps["401"]; !ok {
- Errorhandler("401", unauthorized)
- }
-
- if _, ok := ErrorMaps["402"]; !ok {
- Errorhandler("402", paymentRequired)
- }
-
- if _, ok := ErrorMaps["403"]; !ok {
- Errorhandler("403", forbidden)
- }
-
- if _, ok := ErrorMaps["404"]; !ok {
- Errorhandler("404", notFound)
- }
-
- if _, ok := ErrorMaps["405"]; !ok {
- Errorhandler("405", methodNotAllowed)
- }
-
- if _, ok := ErrorMaps["500"]; !ok {
- Errorhandler("500", internalServerError)
- }
- if _, ok := ErrorMaps["501"]; !ok {
- Errorhandler("501", notImplemented)
- }
- if _, ok := ErrorMaps["502"]; !ok {
- Errorhandler("502", badGateway)
- }
-
- if _, ok := ErrorMaps["503"]; !ok {
- Errorhandler("503", serviceUnavailable)
- }
-
- if _, ok := ErrorMaps["504"]; !ok {
- Errorhandler("504", gatewayTimeout)
- }
-}
-
// ErrorHandler registers http.HandlerFunc to each http err code string.
// usage:
// beego.ErrorHandler("404",NotFound)
// beego.ErrorHandler("500",InternalServerError)
-func Errorhandler(code string, h http.HandlerFunc) *App {
- errinfo := &errorInfo{}
- errinfo.errorType = errorTypeHandler
- errinfo.handler = h
- errinfo.method = code
- ErrorMaps[code] = errinfo
+func ErrorHandler(code string, h http.HandlerFunc) *App {
+ ErrorMaps[code] = &errorInfo{
+ errorType: errorTypeHandler,
+ handler: h,
+ method: code,
+ }
return BeeApp
}
// ErrorController registers ControllerInterface to each http err code string.
// usage:
-// beego.ErrorHandler(&controllers.ErrorController{})
+// beego.ErrorController(&controllers.ErrorController{})
func ErrorController(c ControllerInterface) *App {
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
for i := 0; i < rt.NumMethod(); i++ {
- if !utils.InSlice(rt.Method(i).Name, exceptMethod) && strings.HasPrefix(rt.Method(i).Name, "Error") {
- errinfo := &errorInfo{}
- errinfo.errorType = errorTypeController
- errinfo.controllerType = ct
- errinfo.method = rt.Method(i).Name
- errname := strings.TrimPrefix(rt.Method(i).Name, "Error")
- ErrorMaps[errname] = errinfo
+ methodName := rt.Method(i).Name
+ if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") {
+ errName := strings.TrimPrefix(methodName, "Error")
+ ErrorMaps[errName] = &errorInfo{
+ errorType: errorTypeController,
+ controllerType: ct,
+ method: methodName,
+ }
}
}
return BeeApp
}
// show error string as simple text message.
-// if error string is empty, show 500 error as default.
-func exception(errcode string, ctx *context.Context) {
- code, err := strconv.Atoi(errcode)
- if err != nil {
- code = 503
+// if error string is empty, show 503 or 500 error as default.
+func exception(errCode string, ctx *context.Context) {
+ atoi := func(code string) int {
+ v, err := strconv.Atoi(code)
+ if err == nil {
+ return v
+ }
+ return 503
}
- ctx.ResponseWriter.WriteHeader(code)
- if h, ok := ErrorMaps[errcode]; ok {
- executeError(h, ctx)
- return
- } else if h, ok := ErrorMaps["503"]; ok {
- executeError(h, ctx)
- return
- } else {
- ctx.WriteString(errcode)
+
+ for _, ec := range []string{errCode, "503", "500"} {
+ if h, ok := ErrorMaps[ec]; ok {
+ executeError(h, ctx, atoi(ec))
+ return
+ }
}
+ //if 50x error has been removed from errorMap
+ ctx.ResponseWriter.WriteHeader(atoi(errCode))
+ ctx.WriteString(errCode)
}
-func executeError(err *errorInfo, ctx *context.Context) {
+func executeError(err *errorInfo, ctx *context.Context, code int) {
if err.errorType == errorTypeHandler {
err.handler(ctx.ResponseWriter, ctx.Request)
return
}
if err.errorType == errorTypeController {
+ ctx.Output.SetStatus(code)
//Invoke the request handler
vc := reflect.New(err.controllerType)
execController, ok := vc.Interface().(ControllerInterface)
@@ -471,16 +443,13 @@ func executeError(err *errorInfo, ctx *context.Context) {
execController.URLMapping()
- in := make([]reflect.Value, 0)
method := vc.MethodByName(err.method)
- method.Call(in)
+ method.Call([]reflect.Value{})
//render template
- if ctx.Output.Status == 0 {
- if AutoRender {
- if err := execController.Render(); err != nil {
- panic(err)
- }
+ if BConfig.WebConfig.AutoRender {
+ if err := execController.Render(); err != nil {
+ panic(err)
}
}
diff --git a/example/beeapi/conf/app.conf b/example/beeapi/conf/app.conf
deleted file mode 100644
index fd1681a2..00000000
--- a/example/beeapi/conf/app.conf
+++ /dev/null
@@ -1,5 +0,0 @@
-appname = beeapi
-httpport = 8080
-runmode = dev
-autorender = false
-copyrequestbody = true
diff --git a/example/beeapi/controllers/default.go b/example/beeapi/controllers/default.go
deleted file mode 100644
index a2184075..00000000
--- a/example/beeapi/controllers/default.go
+++ /dev/null
@@ -1,63 +0,0 @@
-// Beego (http://beego.me/)
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-// @link http://github.com/astaxie/beego for the canonical source repository
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-// @authors astaxie
-
-package controllers
-
-import (
- "encoding/json"
-
- "github.com/astaxie/beego"
- "github.com/astaxie/beego/example/beeapi/models"
-)
-
-type ObjectController struct {
- beego.Controller
-}
-
-func (o *ObjectController) Post() {
- var ob models.Object
- json.Unmarshal(o.Ctx.Input.RequestBody, &ob)
- objectid := models.AddOne(ob)
- o.Data["json"] = map[string]string{"ObjectId": objectid}
- o.ServeJson()
-}
-
-func (o *ObjectController) Get() {
- objectId := o.Ctx.Input.Params[":objectId"]
- if objectId != "" {
- ob, err := models.GetOne(objectId)
- if err != nil {
- o.Data["json"] = err
- } else {
- o.Data["json"] = ob
- }
- } else {
- obs := models.GetAll()
- o.Data["json"] = obs
- }
- o.ServeJson()
-}
-
-func (o *ObjectController) Put() {
- objectId := o.Ctx.Input.Params[":objectId"]
- var ob models.Object
- json.Unmarshal(o.Ctx.Input.RequestBody, &ob)
-
- err := models.Update(objectId, ob.Score)
- if err != nil {
- o.Data["json"] = err
- } else {
- o.Data["json"] = "update success!"
- }
- o.ServeJson()
-}
-
-func (o *ObjectController) Delete() {
- objectId := o.Ctx.Input.Params[":objectId"]
- models.Delete(objectId)
- o.Data["json"] = "delete success!"
- o.ServeJson()
-}
diff --git a/example/beeapi/main.go b/example/beeapi/main.go
deleted file mode 100644
index c1250e03..00000000
--- a/example/beeapi/main.go
+++ /dev/null
@@ -1,30 +0,0 @@
-// Beego (http://beego.me/)
-
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-
-// @link http://github.com/astaxie/beego for the canonical source repository
-
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-
-// @authors astaxie
-
-package main
-
-import (
- "github.com/astaxie/beego"
- "github.com/astaxie/beego/example/beeapi/controllers"
-)
-
-// Objects
-
-// URL HTTP Verb Functionality
-// /object POST Creating Objects
-// /object/ GET Retrieving Objects
-// /object/ PUT Updating Objects
-// /object GET Queries
-// /object/ DELETE Deleting Objects
-
-func main() {
- beego.RESTRouter("/object", &controllers.ObjectController{})
- beego.Run()
-}
diff --git a/example/beeapi/models/object.go b/example/beeapi/models/object.go
deleted file mode 100644
index 46109c50..00000000
--- a/example/beeapi/models/object.go
+++ /dev/null
@@ -1,58 +0,0 @@
-// Beego (http://beego.me/)
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-// @link http://github.com/astaxie/beego for the canonical source repository
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-// @authors astaxie
-
-package models
-
-import (
- "errors"
- "strconv"
- "time"
-)
-
-var (
- Objects map[string]*Object
-)
-
-type Object struct {
- ObjectId string
- Score int64
- PlayerName string
-}
-
-func init() {
- Objects = make(map[string]*Object)
- Objects["hjkhsbnmn123"] = &Object{"hjkhsbnmn123", 100, "astaxie"}
- Objects["mjjkxsxsaa23"] = &Object{"mjjkxsxsaa23", 101, "someone"}
-}
-
-func AddOne(object Object) (ObjectId string) {
- object.ObjectId = "astaxie" + strconv.FormatInt(time.Now().UnixNano(), 10)
- Objects[object.ObjectId] = &object
- return object.ObjectId
-}
-
-func GetOne(ObjectId string) (object *Object, err error) {
- if v, ok := Objects[ObjectId]; ok {
- return v, nil
- }
- return nil, errors.New("ObjectId Not Exist")
-}
-
-func GetAll() map[string]*Object {
- return Objects
-}
-
-func Update(ObjectId string, Score int64) (err error) {
- if v, ok := Objects[ObjectId]; ok {
- v.Score = Score
- return nil
- }
- return errors.New("ObjectId Not Exist")
-}
-
-func Delete(ObjectId string) {
- delete(Objects, ObjectId)
-}
diff --git a/example/chat/conf/app.conf b/example/chat/conf/app.conf
deleted file mode 100644
index 9d5a823f..00000000
--- a/example/chat/conf/app.conf
+++ /dev/null
@@ -1,3 +0,0 @@
-appname = chat
-httpport = 8080
-runmode = dev
diff --git a/example/chat/controllers/default.go b/example/chat/controllers/default.go
deleted file mode 100644
index 95ef8a9b..00000000
--- a/example/chat/controllers/default.go
+++ /dev/null
@@ -1,20 +0,0 @@
-// Beego (http://beego.me/)
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-// @link http://github.com/astaxie/beego for the canonical source repository
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-// @authors Unknwon
-
-package controllers
-
-import (
- "github.com/astaxie/beego"
-)
-
-type MainController struct {
- beego.Controller
-}
-
-func (m *MainController) Get() {
- m.Data["host"] = m.Ctx.Request.Host
- m.TplNames = "index.tpl"
-}
diff --git a/example/chat/controllers/ws.go b/example/chat/controllers/ws.go
deleted file mode 100644
index fc3917b3..00000000
--- a/example/chat/controllers/ws.go
+++ /dev/null
@@ -1,181 +0,0 @@
-// Beego (http://beego.me/)
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-// @link http://github.com/astaxie/beego for the canonical source repository
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-// @authors Unknwon
-
-package controllers
-
-import (
- "io/ioutil"
- "math/rand"
- "net/http"
- "time"
-
- "github.com/astaxie/beego"
- "github.com/gorilla/websocket"
-)
-
-const (
- // Time allowed to write a message to the client.
- writeWait = 10 * time.Second
-
- // Time allowed to read the next message from the client.
- readWait = 60 * time.Second
-
- // Send pings to client with this period. Must be less than readWait.
- pingPeriod = (readWait * 9) / 10
-
- // Maximum message size allowed from client.
- maxMessageSize = 512
-)
-
-func init() {
- rand.Seed(time.Now().UTC().UnixNano())
- go h.run()
-}
-
-// connection is an middleman between the websocket connection and the hub.
-type connection struct {
- username string
-
- // The websocket connection.
- ws *websocket.Conn
-
- // Buffered channel of outbound messages.
- send chan []byte
-}
-
-// readPump pumps messages from the websocket connection to the hub.
-func (c *connection) readPump() {
- defer func() {
- h.unregister <- c
- c.ws.Close()
- }()
- c.ws.SetReadLimit(maxMessageSize)
- c.ws.SetReadDeadline(time.Now().Add(readWait))
- for {
- op, r, err := c.ws.NextReader()
- if err != nil {
- break
- }
- switch op {
- case websocket.PongMessage:
- c.ws.SetReadDeadline(time.Now().Add(readWait))
- case websocket.TextMessage:
- message, err := ioutil.ReadAll(r)
- if err != nil {
- break
- }
- h.broadcast <- []byte(c.username + "_" + time.Now().Format("15:04:05") + ":" + string(message))
- }
- }
-}
-
-// write writes a message with the given opCode and payload.
-func (c *connection) write(opCode int, payload []byte) error {
- c.ws.SetWriteDeadline(time.Now().Add(writeWait))
- return c.ws.WriteMessage(opCode, payload)
-}
-
-// writePump pumps messages from the hub to the websocket connection.
-func (c *connection) writePump() {
- ticker := time.NewTicker(pingPeriod)
- defer func() {
- ticker.Stop()
- c.ws.Close()
- }()
- for {
- select {
- case message, ok := <-c.send:
- if !ok {
- c.write(websocket.CloseMessage, []byte{})
- return
- }
- if err := c.write(websocket.TextMessage, message); err != nil {
- return
- }
- case <-ticker.C:
- if err := c.write(websocket.PingMessage, []byte{}); err != nil {
- return
- }
- }
- }
-}
-
-type hub struct {
- // Registered connections.
- connections map[*connection]bool
-
- // Inbound messages from the connections.
- broadcast chan []byte
-
- // Register requests from the connections.
- register chan *connection
-
- // Unregister requests from connections.
- unregister chan *connection
-}
-
-var h = &hub{
- broadcast: make(chan []byte, maxMessageSize),
- register: make(chan *connection, 1),
- unregister: make(chan *connection, 1),
- connections: make(map[*connection]bool),
-}
-
-func (h *hub) run() {
- for {
- select {
- case c := <-h.register:
- h.connections[c] = true
- case c := <-h.unregister:
- delete(h.connections, c)
- close(c.send)
- case m := <-h.broadcast:
- for c := range h.connections {
- select {
- case c.send <- m:
- default:
- close(c.send)
- delete(h.connections, c)
- }
- }
- }
- }
-}
-
-type WSController struct {
- beego.Controller
-}
-
-var upgrader = websocket.Upgrader{
- ReadBufferSize: 1024,
- WriteBufferSize: 1024,
-}
-
-func (w *WSController) Get() {
- ws, err := upgrader.Upgrade(w.Ctx.ResponseWriter, w.Ctx.Request, nil)
- if _, ok := err.(websocket.HandshakeError); ok {
- http.Error(w.Ctx.ResponseWriter, "Not a websocket handshake", 400)
- return
- } else if err != nil {
- return
- }
- c := &connection{send: make(chan []byte, 256), ws: ws, username: randomString(10)}
- h.register <- c
- go c.writePump()
- c.readPump()
-}
-
-func randomString(l int) string {
- bytes := make([]byte, l)
- for i := 0; i < l; i++ {
- bytes[i] = byte(randInt(65, 90))
- }
- return string(bytes)
-}
-
-func randInt(min int, max int) int {
- return min + rand.Intn(max-min)
-}
diff --git a/example/chat/main.go b/example/chat/main.go
deleted file mode 100644
index be055b25..00000000
--- a/example/chat/main.go
+++ /dev/null
@@ -1,17 +0,0 @@
-// Beego (http://beego.me/)
-// @description beego is an open-source, high-performance web framework for the Go programming language.
-// @link http://github.com/astaxie/beego for the canonical source repository
-// @license http://github.com/astaxie/beego/blob/master/LICENSE
-// @authors Unknwon
-package main
-
-import (
- "github.com/astaxie/beego"
- "github.com/astaxie/beego/example/chat/controllers"
-)
-
-func main() {
- beego.Router("/", &controllers.MainController{})
- beego.Router("/ws", &controllers.WSController{})
- beego.Run()
-}
diff --git a/example/chat/views/index.tpl b/example/chat/views/index.tpl
deleted file mode 100644
index 3a9d8838..00000000
--- a/example/chat/views/index.tpl
+++ /dev/null
@@ -1,92 +0,0 @@
-
-
-
-Chat Example
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/filter.go b/filter.go
index ddd61094..863223f7 100644
--- a/filter.go
+++ b/filter.go
@@ -16,11 +16,12 @@ package beego
import "github.com/astaxie/beego/context"
-// FilterFunc defines filter function type.
+// FilterFunc defines a filter function which is invoked before the controller handler is executed.
type FilterFunc func(*context.Context)
-// FilterRouter defines filter operation before controller handler execution.
-// it can match patterned url and do filter function when action arrives.
+// FilterRouter defines a filter operation which is invoked before the controller handler is executed.
+// It can match the URL against a pattern, and execute a filter function
+// when a request with a matching URL arrives.
type FilterRouter struct {
filterFunc FilterFunc
tree *Tree
@@ -28,16 +29,15 @@ type FilterRouter struct {
returnOnOutput bool
}
-// ValidRouter check current request is valid for this filter.
-// if matched, returns parsed params in this request by defined filter router pattern.
-func (f *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
- isok, params := f.tree.Match(router)
- if isok == nil {
- return false, nil
- }
- if isok, ok := isok.(bool); ok {
- return isok, params
- } else {
- return false, nil
+// ValidRouter checks if the current request is matched by this filter.
+// If the request is matched, the values of the URL parameters defined
+// by the filter pattern are also returned.
+func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool {
+ isOk := f.tree.Match(url, ctx)
+ if isOk != nil {
+ if b, ok := isOk.(bool); ok {
+ return b
+ }
}
+ return false
}
diff --git a/filter_test.go b/filter_test.go
index ff6f750b..d9928d8d 100644
--- a/filter_test.go
+++ b/filter_test.go
@@ -20,10 +20,16 @@ import (
"testing"
"github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
)
+func init() {
+ BeeLogger = logs.NewLogger(10000)
+ BeeLogger.SetLogger("console", "")
+}
+
var FilterUser = func(ctx *context.Context) {
- ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"]))
+ ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first")))
}
func TestFilter(t *testing.T) {
diff --git a/flash.go b/flash.go
index 5ccf339a..a6485a17 100644
--- a/flash.go
+++ b/flash.go
@@ -83,27 +83,27 @@ func (fd *FlashData) Store(c *Controller) {
c.Data["flash"] = fd.Data
var flashValue string
for key, value := range fd.Data {
- flashValue += "\x00" + key + "\x23" + FlashSeperator + "\x23" + value + "\x00"
+ flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00"
}
- c.Ctx.SetCookie(FlashName, url.QueryEscape(flashValue), 0, "/")
+ c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/")
}
// ReadFromRequest parsed flash data from encoded values in cookie.
func ReadFromRequest(c *Controller) *FlashData {
flash := NewFlash()
- if cookie, err := c.Ctx.Request.Cookie(FlashName); err == nil {
+ if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil {
v, _ := url.QueryUnescape(cookie.Value)
vals := strings.Split(v, "\x00")
for _, v := range vals {
if len(v) > 0 {
- kv := strings.Split(v, "\x23"+FlashSeperator+"\x23")
+ kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23")
if len(kv) == 2 {
flash.Data[kv[0]] = kv[1]
}
}
}
//read one time then delete it
- c.Ctx.SetCookie(FlashName, "", -1, "/")
+ c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/")
}
c.Data["flash"] = flash.Data
return flash
diff --git a/flash_test.go b/flash_test.go
index b655f552..640d54de 100644
--- a/flash_test.go
+++ b/flash_test.go
@@ -30,7 +30,7 @@ func (t *TestFlashController) TestWriteFlash() {
flash.Notice("TestFlashString")
flash.Store(&t.Controller)
// we choose to serve json because we don't want to load a template html file
- t.ServeJson(true)
+ t.ServeJSON(true)
}
func TestFlashHeader(t *testing.T) {
diff --git a/grace/conn.go b/grace/conn.go
new file mode 100644
index 00000000..6807e1ac
--- /dev/null
+++ b/grace/conn.go
@@ -0,0 +1,28 @@
+package grace
+
+import (
+ "errors"
+ "net"
+)
+
+type graceConn struct {
+ net.Conn
+ server *Server
+}
+
+func (c graceConn) Close() (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ switch x := r.(type) {
+ case string:
+ err = errors.New(x)
+ case error:
+ err = x
+ default:
+ err = errors.New("Unknown panic")
+ }
+ }
+ }()
+ c.server.wg.Done()
+ return c.Conn.Close()
+}
diff --git a/grace/grace.go b/grace/grace.go
new file mode 100644
index 00000000..af4e9068
--- /dev/null
+++ b/grace/grace.go
@@ -0,0 +1,158 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package grace use to hot reload
+// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/
+//
+// Usage:
+//
+// import(
+// "log"
+// "net/http"
+// "os"
+//
+// "github.com/astaxie/beego/grace"
+// )
+//
+// func handler(w http.ResponseWriter, r *http.Request) {
+// w.Write([]byte("WORLD!"))
+// }
+//
+// func main() {
+// mux := http.NewServeMux()
+// mux.HandleFunc("/hello", handler)
+//
+// err := grace.ListenAndServe("localhost:8080", mux)
+// if err != nil {
+// log.Println(err)
+// }
+// log.Println("Server on 8080 stopped")
+// os.Exit(0)
+// }
+package grace
+
+import (
+ "flag"
+ "net/http"
+ "os"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+)
+
+const (
+ // PreSignal is the position to add filter before signal
+ PreSignal = iota
+ // PostSignal is the position to add filter after signal
+ PostSignal
+ // StateInit represent the application inited
+ StateInit
+ // StateRunning represent the application is running
+ StateRunning
+ // StateShuttingDown represent the application is shutting down
+ StateShuttingDown
+ // StateTerminate represent the application is killed
+ StateTerminate
+)
+
+var (
+ regLock *sync.Mutex
+ runningServers map[string]*Server
+ runningServersOrder []string
+ socketPtrOffsetMap map[string]uint
+ runningServersForked bool
+
+ // DefaultReadTimeOut is the HTTP read timeout
+ DefaultReadTimeOut time.Duration
+ // DefaultWriteTimeOut is the HTTP Write timeout
+ DefaultWriteTimeOut time.Duration
+ // DefaultMaxHeaderBytes is the Max HTTP Herder size, default is 0, no limit
+ DefaultMaxHeaderBytes int
+ // DefaultTimeout is the shutdown server's timeout. default is 60s
+ DefaultTimeout = 60 * time.Second
+
+ isChild bool
+ socketOrder string
+ once sync.Once
+)
+
+func onceInit() {
+ regLock = &sync.Mutex{}
+ flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
+ flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
+ runningServers = make(map[string]*Server)
+ runningServersOrder = []string{}
+ socketPtrOffsetMap = make(map[string]uint)
+}
+
+// NewServer returns a new graceServer.
+func NewServer(addr string, handler http.Handler) (srv *Server) {
+ once.Do(onceInit)
+ regLock.Lock()
+ defer regLock.Unlock()
+ if !flag.Parsed() {
+ flag.Parse()
+ }
+ if len(socketOrder) > 0 {
+ for i, addr := range strings.Split(socketOrder, ",") {
+ socketPtrOffsetMap[addr] = uint(i)
+ }
+ } else {
+ socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
+ }
+
+ srv = &Server{
+ wg: sync.WaitGroup{},
+ sigChan: make(chan os.Signal),
+ isChild: isChild,
+ SignalHooks: map[int]map[os.Signal][]func(){
+ PreSignal: {
+ syscall.SIGHUP: {},
+ syscall.SIGINT: {},
+ syscall.SIGTERM: {},
+ },
+ PostSignal: {
+ syscall.SIGHUP: {},
+ syscall.SIGINT: {},
+ syscall.SIGTERM: {},
+ },
+ },
+ state: StateInit,
+ Network: "tcp",
+ }
+ srv.Server = &http.Server{}
+ srv.Server.Addr = addr
+ srv.Server.ReadTimeout = DefaultReadTimeOut
+ srv.Server.WriteTimeout = DefaultWriteTimeOut
+ srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes
+ srv.Server.Handler = handler
+
+ runningServersOrder = append(runningServersOrder, addr)
+ runningServers[addr] = srv
+
+ return
+}
+
+// ListenAndServe refer http.ListenAndServe
+func ListenAndServe(addr string, handler http.Handler) error {
+ server := NewServer(addr, handler)
+ return server.ListenAndServe()
+}
+
+// ListenAndServeTLS refer http.ListenAndServeTLS
+func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
+ server := NewServer(addr, handler)
+ return server.ListenAndServeTLS(certFile, keyFile)
+}
diff --git a/grace/listener.go b/grace/listener.go
new file mode 100644
index 00000000..5439d0b2
--- /dev/null
+++ b/grace/listener.go
@@ -0,0 +1,62 @@
+package grace
+
+import (
+ "net"
+ "os"
+ "syscall"
+ "time"
+)
+
+type graceListener struct {
+ net.Listener
+ stop chan error
+ stopped bool
+ server *Server
+}
+
+func newGraceListener(l net.Listener, srv *Server) (el *graceListener) {
+ el = &graceListener{
+ Listener: l,
+ stop: make(chan error),
+ server: srv,
+ }
+ go func() {
+ _ = <-el.stop
+ el.stopped = true
+ el.stop <- el.Listener.Close()
+ }()
+ return
+}
+
+func (gl *graceListener) Accept() (c net.Conn, err error) {
+ tc, err := gl.Listener.(*net.TCPListener).AcceptTCP()
+ if err != nil {
+ return
+ }
+
+ tc.SetKeepAlive(true)
+ tc.SetKeepAlivePeriod(3 * time.Minute)
+
+ c = graceConn{
+ Conn: tc,
+ server: gl.server,
+ }
+
+ gl.server.wg.Add(1)
+ return
+}
+
+func (gl *graceListener) Close() error {
+ if gl.stopped {
+ return syscall.EINVAL
+ }
+ gl.stop <- nil
+ return <-gl.stop
+}
+
+func (gl *graceListener) File() *os.File {
+ // returns a dup(2) - FD_CLOEXEC flag *not* set
+ tl := gl.Listener.(*net.TCPListener)
+ fl, _ := tl.File()
+ return fl
+}
diff --git a/grace/server.go b/grace/server.go
new file mode 100644
index 00000000..f4512ded
--- /dev/null
+++ b/grace/server.go
@@ -0,0 +1,293 @@
+package grace
+
+import (
+ "crypto/tls"
+ "fmt"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "os/exec"
+ "os/signal"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+)
+
+// Server embedded http.Server
+type Server struct {
+ *http.Server
+ GraceListener net.Listener
+ SignalHooks map[int]map[os.Signal][]func()
+ tlsInnerListener *graceListener
+ wg sync.WaitGroup
+ sigChan chan os.Signal
+ isChild bool
+ state uint8
+ Network string
+}
+
+// Serve accepts incoming connections on the Listener l,
+// creating a new service goroutine for each.
+// The service goroutines read requests and then call srv.Handler to reply to them.
+func (srv *Server) Serve() (err error) {
+ srv.state = StateRunning
+ err = srv.Server.Serve(srv.GraceListener)
+ log.Println(syscall.Getpid(), "Waiting for connections to finish...")
+ srv.wg.Wait()
+ srv.state = StateTerminate
+ return
+}
+
+// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
+// to handle requests on incoming connections. If srv.Addr is blank, ":http" is
+// used.
+func (srv *Server) ListenAndServe() (err error) {
+ addr := srv.Addr
+ if addr == "" {
+ addr = ":http"
+ }
+
+ go srv.handleSignals()
+
+ l, err := srv.getListener(addr)
+ if err != nil {
+ log.Println(err)
+ return err
+ }
+
+ srv.GraceListener = newGraceListener(l, srv)
+
+ if srv.isChild {
+ process, err := os.FindProcess(os.Getppid())
+ if err != nil {
+ log.Println(err)
+ return err
+ }
+ err = process.Kill()
+ if err != nil {
+ return err
+ }
+ }
+
+ log.Println(os.Getpid(), srv.Addr)
+ return srv.Serve()
+}
+
+// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
+// Serve to handle requests on incoming TLS connections.
+//
+// Filenames containing a certificate and matching private key for the server must
+// be provided. If the certificate is signed by a certificate authority, the
+// certFile should be the concatenation of the server's certificate followed by the
+// CA's certificate.
+//
+// If srv.Addr is blank, ":https" is used.
+func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
+ addr := srv.Addr
+ if addr == "" {
+ addr = ":https"
+ }
+
+ config := &tls.Config{}
+ if srv.TLSConfig != nil {
+ *config = *srv.TLSConfig
+ }
+ if config.NextProtos == nil {
+ config.NextProtos = []string{"http/1.1"}
+ }
+
+ config.Certificates = make([]tls.Certificate, 1)
+ config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ return
+ }
+
+ go srv.handleSignals()
+
+ l, err := srv.getListener(addr)
+ if err != nil {
+ log.Println(err)
+ return err
+ }
+
+ srv.tlsInnerListener = newGraceListener(l, srv)
+ srv.GraceListener = tls.NewListener(srv.tlsInnerListener, config)
+
+ if srv.isChild {
+ process, err := os.FindProcess(os.Getppid())
+ if err != nil {
+ log.Println(err)
+ return err
+ }
+ err = process.Kill()
+ if err != nil {
+ return err
+ }
+ }
+ log.Println(os.Getpid(), srv.Addr)
+ return srv.Serve()
+}
+
+// getListener either opens a new socket to listen on, or takes the acceptor socket
+// it got passed when restarted.
+func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
+ if srv.isChild {
+ var ptrOffset uint
+ if len(socketPtrOffsetMap) > 0 {
+ ptrOffset = socketPtrOffsetMap[laddr]
+ log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
+ }
+
+ f := os.NewFile(uintptr(3+ptrOffset), "")
+ l, err = net.FileListener(f)
+ if err != nil {
+ err = fmt.Errorf("net.FileListener error: %v", err)
+ return
+ }
+ } else {
+ l, err = net.Listen(srv.Network, laddr)
+ if err != nil {
+ err = fmt.Errorf("net.Listen error: %v", err)
+ return
+ }
+ }
+ return
+}
+
+// handleSignals listens for os Signals and calls any hooked in function that the
+// user had registered with the signal.
+func (srv *Server) handleSignals() {
+ var sig os.Signal
+
+ signal.Notify(
+ srv.sigChan,
+ syscall.SIGHUP,
+ syscall.SIGINT,
+ syscall.SIGTERM,
+ )
+
+ pid := syscall.Getpid()
+ for {
+ sig = <-srv.sigChan
+ srv.signalHooks(PreSignal, sig)
+ switch sig {
+ case syscall.SIGHUP:
+ log.Println(pid, "Received SIGHUP. forking.")
+ err := srv.fork()
+ if err != nil {
+ log.Println("Fork err:", err)
+ }
+ case syscall.SIGINT:
+ log.Println(pid, "Received SIGINT.")
+ srv.shutdown()
+ case syscall.SIGTERM:
+ log.Println(pid, "Received SIGTERM.")
+ srv.shutdown()
+ default:
+ log.Printf("Received %v: nothing i care about...\n", sig)
+ }
+ srv.signalHooks(PostSignal, sig)
+ }
+}
+
+func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
+ if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
+ return
+ }
+ for _, f := range srv.SignalHooks[ppFlag][sig] {
+ f()
+ }
+ return
+}
+
+// shutdown closes the listener so that no new connections are accepted. it also
+// starts a goroutine that will serverTimeout (stop all running requests) the server
+// after DefaultTimeout.
+func (srv *Server) shutdown() {
+ if srv.state != StateRunning {
+ return
+ }
+
+ srv.state = StateShuttingDown
+ if DefaultTimeout >= 0 {
+ go srv.serverTimeout(DefaultTimeout)
+ }
+ err := srv.GraceListener.Close()
+ if err != nil {
+ log.Println(syscall.Getpid(), "Listener.Close() error:", err)
+ } else {
+ log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
+ }
+}
+
+// serverTimeout forces the server to shutdown in a given timeout - whether it
+// finished outstanding requests or not. if Read/WriteTimeout are not set or the
+// max header size is very big a connection could hang
+func (srv *Server) serverTimeout(d time.Duration) {
+ defer func() {
+ if r := recover(); r != nil {
+ log.Println("WaitGroup at 0", r)
+ }
+ }()
+ if srv.state != StateShuttingDown {
+ return
+ }
+ time.Sleep(d)
+ log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
+ for {
+ if srv.state == StateTerminate {
+ break
+ }
+ srv.wg.Done()
+ }
+}
+
+func (srv *Server) fork() (err error) {
+ regLock.Lock()
+ defer regLock.Unlock()
+ if runningServersForked {
+ return
+ }
+ runningServersForked = true
+
+ var files = make([]*os.File, len(runningServers))
+ var orderArgs = make([]string, len(runningServers))
+ for _, srvPtr := range runningServers {
+ switch srvPtr.GraceListener.(type) {
+ case *graceListener:
+ files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
+ default:
+ files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
+ }
+ orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
+ }
+
+ log.Println(files)
+ path := os.Args[0]
+ var args []string
+ if len(os.Args) > 1 {
+ for _, arg := range os.Args[1:] {
+ if arg == "-graceful" {
+ break
+ }
+ args = append(args, arg)
+ }
+ }
+ args = append(args, "-graceful")
+ if len(runningServers) > 1 {
+ args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
+ log.Println(args)
+ }
+ cmd := exec.Command(path, args...)
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ cmd.ExtraFiles = files
+ err = cmd.Start()
+ if err != nil {
+ log.Fatalf("Restart: Failed to launch, error: %v", err)
+ }
+
+ return
+}
diff --git a/hooks.go b/hooks.go
new file mode 100644
index 00000000..78abf8ef
--- /dev/null
+++ b/hooks.go
@@ -0,0 +1,95 @@
+package beego
+
+import (
+ "encoding/json"
+ "mime"
+ "net/http"
+ "path/filepath"
+
+ "github.com/astaxie/beego/session"
+)
+
+//
+func registerMime() error {
+ for k, v := range mimemaps {
+ mime.AddExtensionType(k, v)
+ }
+ return nil
+}
+
+// register default error http handlers, 404,401,403,500 and 503.
+func registerDefaultErrorHandler() error {
+ m := map[string]func(http.ResponseWriter, *http.Request){
+ "401": unauthorized,
+ "402": paymentRequired,
+ "403": forbidden,
+ "404": notFound,
+ "405": methodNotAllowed,
+ "500": internalServerError,
+ "501": notImplemented,
+ "502": badGateway,
+ "503": serviceUnavailable,
+ "504": gatewayTimeout,
+ }
+ for e, h := range m {
+ if _, ok := ErrorMaps[e]; !ok {
+ ErrorHandler(e, h)
+ }
+ }
+ return nil
+}
+
+func registerSession() error {
+ if BConfig.WebConfig.Session.SessionOn {
+ var err error
+ sessionConfig := AppConfig.String("sessionConfig")
+ if sessionConfig == "" {
+ conf := map[string]interface{}{
+ "cookieName": BConfig.WebConfig.Session.SessionName,
+ "gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime,
+ "providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig),
+ "secure": BConfig.Listen.EnableHTTPS,
+ "enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie,
+ "domain": BConfig.WebConfig.Session.SessionDomain,
+ "cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime,
+ }
+ confBytes, err := json.Marshal(conf)
+ if err != nil {
+ return err
+ }
+ sessionConfig = string(confBytes)
+ }
+ if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, sessionConfig); err != nil {
+ return err
+ }
+ go GlobalSessions.GC()
+ }
+ return nil
+}
+
+func registerTemplate() error {
+ if BConfig.WebConfig.AutoRender {
+ if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil {
+ if BConfig.RunMode == DEV {
+ Warn(err)
+ }
+ return err
+ }
+ }
+ return nil
+}
+
+func registerDocs() error {
+ if BConfig.WebConfig.EnableDocs {
+ Get("/docs", serverDocs)
+ Get("/docs/*", serverDocs)
+ }
+ return nil
+}
+
+func registerAdmin() error {
+ if BConfig.Listen.EnableAdmin {
+ go beeAdminApp.Run()
+ }
+ return nil
+}
diff --git a/httplib/httplib.go b/httplib/httplib.go
index 7ff2f1d2..fb64a30a 100644
--- a/httplib/httplib.go
+++ b/httplib/httplib.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package httplib is used as http.Client
// Usage:
//
// import "github.com/astaxie/beego/httplib"
@@ -32,6 +33,7 @@ package httplib
import (
"bytes"
+ "compress/gzip"
"crypto/tls"
"encoding/json"
"encoding/xml"
@@ -50,7 +52,14 @@ import (
"time"
)
-var defaultSetting = BeegoHttpSettings{false, "beegoServer", 60 * time.Second, 60 * time.Second, nil, nil, nil, false}
+var defaultSetting = BeegoHTTPSettings{
+ UserAgent: "beegoServer",
+ ConnectTimeout: 60 * time.Second,
+ ReadWriteTimeout: 60 * time.Second,
+ Gzip: true,
+ DumpBody: true,
+}
+
var defaultCookieJar http.CookieJar
var settingMutex sync.Mutex
@@ -61,132 +70,163 @@ func createDefaultCookie() {
defaultCookieJar, _ = cookiejar.New(nil)
}
-// Overwrite default settings
-func SetDefaultSetting(setting BeegoHttpSettings) {
+// SetDefaultSetting Overwrite default settings
+func SetDefaultSetting(setting BeegoHTTPSettings) {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultSetting = setting
- if defaultSetting.ConnectTimeout == 0 {
- defaultSetting.ConnectTimeout = 60 * time.Second
- }
- if defaultSetting.ReadWriteTimeout == 0 {
- defaultSetting.ReadWriteTimeout = 60 * time.Second
- }
}
-// return *BeegoHttpRequest with specific method
-func newBeegoRequest(url, method string) *BeegoHttpRequest {
+// NewBeegoRequest return *BeegoHttpRequest with specific method
+func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest {
var resp http.Response
+ u, err := url.Parse(rawurl)
+ if err != nil {
+ log.Println("Httplib:", err)
+ }
req := http.Request{
+ URL: u,
Method: method,
Header: make(http.Header),
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
}
- return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting, &resp, nil}
+ return &BeegoHTTPRequest{
+ url: rawurl,
+ req: &req,
+ params: map[string][]string{},
+ files: map[string]string{},
+ setting: defaultSetting,
+ resp: &resp,
+ }
}
// Get returns *BeegoHttpRequest with GET method.
-func Get(url string) *BeegoHttpRequest {
- return newBeegoRequest(url, "GET")
+func Get(url string) *BeegoHTTPRequest {
+ return NewBeegoRequest(url, "GET")
}
// Post returns *BeegoHttpRequest with POST method.
-func Post(url string) *BeegoHttpRequest {
- return newBeegoRequest(url, "POST")
+func Post(url string) *BeegoHTTPRequest {
+ return NewBeegoRequest(url, "POST")
}
// Put returns *BeegoHttpRequest with PUT method.
-func Put(url string) *BeegoHttpRequest {
- return newBeegoRequest(url, "PUT")
+func Put(url string) *BeegoHTTPRequest {
+ return NewBeegoRequest(url, "PUT")
}
// Delete returns *BeegoHttpRequest DELETE method.
-func Delete(url string) *BeegoHttpRequest {
- return newBeegoRequest(url, "DELETE")
+func Delete(url string) *BeegoHTTPRequest {
+ return NewBeegoRequest(url, "DELETE")
}
// Head returns *BeegoHttpRequest with HEAD method.
-func Head(url string) *BeegoHttpRequest {
- return newBeegoRequest(url, "HEAD")
+func Head(url string) *BeegoHTTPRequest {
+ return NewBeegoRequest(url, "HEAD")
}
-// BeegoHttpSettings
-type BeegoHttpSettings struct {
+// BeegoHTTPSettings is the http.Client setting
+type BeegoHTTPSettings struct {
ShowDebug bool
UserAgent string
ConnectTimeout time.Duration
ReadWriteTimeout time.Duration
- TlsClientConfig *tls.Config
+ TLSClientConfig *tls.Config
Proxy func(*http.Request) (*url.URL, error)
Transport http.RoundTripper
EnableCookie bool
+ Gzip bool
+ DumpBody bool
}
-// BeegoHttpRequest provides more useful methods for requesting one url than http.Request.
-type BeegoHttpRequest struct {
+// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request.
+type BeegoHTTPRequest struct {
url string
req *http.Request
- params map[string]string
+ params map[string][]string
files map[string]string
- setting BeegoHttpSettings
+ setting BeegoHTTPSettings
resp *http.Response
body []byte
+ dump []byte
}
-// Change request settings
-func (b *BeegoHttpRequest) Setting(setting BeegoHttpSettings) *BeegoHttpRequest {
+// GetRequest return the request object
+func (b *BeegoHTTPRequest) GetRequest() *http.Request {
+ return b.req
+}
+
+// Setting Change request settings
+func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest {
b.setting = setting
return b
}
// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password.
-func (b *BeegoHttpRequest) SetBasicAuth(username, password string) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest {
b.req.SetBasicAuth(username, password)
return b
}
// SetEnableCookie sets enable/disable cookiejar
-func (b *BeegoHttpRequest) SetEnableCookie(enable bool) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest {
b.setting.EnableCookie = enable
return b
}
// SetUserAgent sets User-Agent header field
-func (b *BeegoHttpRequest) SetUserAgent(useragent string) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest {
b.setting.UserAgent = useragent
return b
}
// Debug sets show debug or not when executing request.
-func (b *BeegoHttpRequest) Debug(isdebug bool) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
b.setting.ShowDebug = isdebug
return b
}
+// DumpBody setting whether need to Dump the Body.
+func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
+ b.setting.DumpBody = isdump
+ return b
+}
+
+// DumpRequest return the DumpRequest
+func (b *BeegoHTTPRequest) DumpRequest() []byte {
+ return b.dump
+}
+
// SetTimeout sets connect time out and read-write time out for BeegoRequest.
-func (b *BeegoHttpRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest {
b.setting.ConnectTimeout = connectTimeout
b.setting.ReadWriteTimeout = readWriteTimeout
return b
}
// SetTLSClientConfig sets tls connection configurations if visiting https url.
-func (b *BeegoHttpRequest) SetTLSClientConfig(config *tls.Config) *BeegoHttpRequest {
- b.setting.TlsClientConfig = config
+func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest {
+ b.setting.TLSClientConfig = config
return b
}
// Header add header item string in request.
-func (b *BeegoHttpRequest) Header(key, value string) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest {
b.req.Header.Set(key, value)
return b
}
-// Set the protocol version for incoming requests.
+// SetHost set the request host
+func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest {
+ b.req.Host = host
+ return b
+}
+
+// SetProtocolVersion Set the protocol version for incoming requests.
// Client requests always use HTTP/1.1.
-func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
if len(vers) == 0 {
vers = "HTTP/1.1"
}
@@ -202,44 +242,49 @@ func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest {
}
// SetCookie add cookie into request.
-func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest {
b.req.Header.Add("Cookie", cookie.String())
return b
}
-// Set transport to
-func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest {
+// SetTransport set the setting transport
+func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest {
b.setting.Transport = transport
return b
}
-// Set http proxy
+// SetProxy set the http proxy
// example:
//
// func(req *http.Request) (*url.URL, error) {
// u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
// return u, nil
// }
-func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest {
b.setting.Proxy = proxy
return b
}
// Param adds query param in to request.
// params build query string as ?key1=value1&key2=value2...
-func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest {
- b.params[key] = value
+func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest {
+ if param, ok := b.params[key]; ok {
+ b.params[key] = append(param, value)
+ } else {
+ b.params[key] = []string{value}
+ }
return b
}
-func (b *BeegoHttpRequest) PostFile(formname, filename string) *BeegoHttpRequest {
+// PostFile add a post file to the request
+func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest {
b.files[formname] = filename
return b
}
// Body adds request raw body.
// it supports string and []byte.
-func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest {
+func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
switch t := data.(type) {
case string:
bf := bytes.NewBufferString(t)
@@ -253,7 +298,22 @@ func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest {
return b
}
-func (b *BeegoHttpRequest) buildUrl(paramBody string) {
+// JSONBody adds request raw body encoding by JSON.
+func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) {
+ if b.req.Body == nil && obj != nil {
+ buf := bytes.NewBuffer(nil)
+ enc := json.NewEncoder(buf)
+ if err := enc.Encode(obj); err != nil {
+ return b, err
+ }
+ b.req.Body = ioutil.NopCloser(buf)
+ b.req.ContentLength = int64(buf.Len())
+ b.req.Header.Set("Content-Type", "application/json")
+ }
+ return b, nil
+}
+
+func (b *BeegoHTTPRequest) buildURL(paramBody string) {
// build GET url with query string
if b.req.Method == "GET" && len(paramBody) > 0 {
if strings.Index(b.url, "?") != -1 {
@@ -264,8 +324,8 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) {
return
}
- // build POST url and body
- if b.req.Method == "POST" && b.req.Body == nil {
+ // build POST/PUT/PATCH url and body
+ if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH") && b.req.Body == nil {
// with files
if len(b.files) > 0 {
pr, pw := io.Pipe()
@@ -274,21 +334,23 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) {
for formname, filename := range b.files {
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
if err != nil {
- log.Fatal(err)
+ log.Println("Httplib:", err)
}
fh, err := os.Open(filename)
if err != nil {
- log.Fatal(err)
+ log.Println("Httplib:", err)
}
//iocopy
_, err = io.Copy(fileWriter, fh)
fh.Close()
if err != nil {
- log.Fatal(err)
+ log.Println("Httplib:", err)
}
}
for k, v := range b.params {
- bodyWriter.WriteField(k, v)
+ for _, vv := range v {
+ bodyWriter.WriteField(k, vv)
+ }
}
bodyWriter.Close()
pw.Close()
@@ -306,24 +368,36 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) {
}
}
-func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
+func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
if b.resp.StatusCode != 0 {
return b.resp, nil
}
+ resp, err := b.DoRequest()
+ if err != nil {
+ return nil, err
+ }
+ b.resp = resp
+ return resp, nil
+}
+
+// DoRequest will do the client.Do
+func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
var paramBody string
if len(b.params) > 0 {
var buf bytes.Buffer
for k, v := range b.params {
- buf.WriteString(url.QueryEscape(k))
- buf.WriteByte('=')
- buf.WriteString(url.QueryEscape(v))
- buf.WriteByte('&')
+ for _, vv := range v {
+ buf.WriteString(url.QueryEscape(k))
+ buf.WriteByte('=')
+ buf.WriteString(url.QueryEscape(vv))
+ buf.WriteByte('&')
+ }
}
paramBody = buf.String()
paramBody = paramBody[0 : len(paramBody)-1]
}
- b.buildUrl(paramBody)
+ b.buildURL(paramBody)
url, err := url.Parse(b.url)
if err != nil {
return nil, err
@@ -336,7 +410,7 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
if trans == nil {
// create default transport
trans = &http.Transport{
- TLSClientConfig: b.setting.TlsClientConfig,
+ TLSClientConfig: b.setting.TLSClientConfig,
Proxy: b.setting.Proxy,
Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
}
@@ -344,7 +418,7 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
// if b.transport is *http.Transport then set the settings.
if t, ok := trans.(*http.Transport); ok {
if t.TLSClientConfig == nil {
- t.TLSClientConfig = b.setting.TlsClientConfig
+ t.TLSClientConfig = b.setting.TLSClientConfig
}
if t.Proxy == nil {
t.Proxy = b.setting.Proxy
@@ -355,7 +429,7 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
}
}
- var jar http.CookieJar = nil
+ var jar http.CookieJar
if b.setting.EnableCookie {
if defaultCookieJar == nil {
createDefaultCookie()
@@ -373,24 +447,18 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
}
if b.setting.ShowDebug {
- dump, err := httputil.DumpRequest(b.req, true)
+ dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody)
if err != nil {
- println(err.Error())
+ log.Println(err.Error())
}
- println(string(dump))
+ b.dump = dump
}
-
- resp, err := client.Do(b.req)
- if err != nil {
- return nil, err
- }
- b.resp = resp
- return resp, nil
+ return client.Do(b.req)
}
// String returns the body string in response.
// it calls Response inner.
-func (b *BeegoHttpRequest) String() (string, error) {
+func (b *BeegoHTTPRequest) String() (string, error) {
data, err := b.Bytes()
if err != nil {
return "", err
@@ -401,7 +469,7 @@ func (b *BeegoHttpRequest) String() (string, error) {
// Bytes returns the body []byte in response.
// it calls Response inner.
-func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
+func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
if b.body != nil {
return b.body, nil
}
@@ -413,16 +481,21 @@ func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
return nil, nil
}
defer resp.Body.Close()
- b.body, err = ioutil.ReadAll(resp.Body)
- if err != nil {
- return nil, err
+ if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" {
+ reader, err := gzip.NewReader(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+ b.body, err = ioutil.ReadAll(reader)
+ } else {
+ b.body, err = ioutil.ReadAll(resp.Body)
}
- return b.body, nil
+ return b.body, err
}
// ToFile saves the body data in response to one file.
// it calls Response inner.
-func (b *BeegoHttpRequest) ToFile(filename string) error {
+func (b *BeegoHTTPRequest) ToFile(filename string) error {
f, err := os.Create(filename)
if err != nil {
return err
@@ -441,9 +514,9 @@ func (b *BeegoHttpRequest) ToFile(filename string) error {
return err
}
-// ToJson returns the map that marshals from the body bytes as json in response .
+// ToJSON returns the map that marshals from the body bytes as json in response .
// it calls Response inner.
-func (b *BeegoHttpRequest) ToJson(v interface{}) error {
+func (b *BeegoHTTPRequest) ToJSON(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
@@ -451,9 +524,9 @@ func (b *BeegoHttpRequest) ToJson(v interface{}) error {
return json.Unmarshal(data, v)
}
-// ToXml returns the map that marshals from the body bytes as xml in response .
+// ToXML returns the map that marshals from the body bytes as xml in response .
// it calls Response inner.
-func (b *BeegoHttpRequest) ToXml(v interface{}) error {
+func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
@@ -462,7 +535,7 @@ func (b *BeegoHttpRequest) ToXml(v interface{}) error {
}
// Response executes request client gets response mannually.
-func (b *BeegoHttpRequest) Response() (*http.Response, error) {
+func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
return b.getResponse()
}
@@ -473,7 +546,7 @@ func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, ad
if err != nil {
return nil, err
}
- conn.SetDeadline(time.Now().Add(rwTimeout))
- return conn, nil
+ err = conn.SetDeadline(time.Now().Add(rwTimeout))
+ return conn, err
}
}
diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go
index 0b551c53..05815054 100644
--- a/httplib/httplib_test.go
+++ b/httplib/httplib_test.go
@@ -19,6 +19,7 @@ import (
"os"
"strings"
"testing"
+ "time"
)
func TestResponse(t *testing.T) {
@@ -149,10 +150,11 @@ func TestWithUserAgent(t *testing.T) {
func TestWithSetting(t *testing.T) {
v := "beego"
- var setting BeegoHttpSettings
+ var setting BeegoHTTPSettings
setting.EnableCookie = true
setting.UserAgent = v
setting.Transport = nil
+ setting.ReadWriteTimeout = 5 * time.Second
SetDefaultSetting(setting)
str, err := Get("http://httpbin.org/get").String()
@@ -176,11 +178,11 @@ func TestToJson(t *testing.T) {
t.Log(resp)
// httpbin will return http remote addr
- type Ip struct {
+ type IP struct {
Origin string `json:"origin"`
}
- var ip Ip
- err = req.ToJson(&ip)
+ var ip IP
+ err = req.ToJSON(&ip)
if err != nil {
t.Fatal(err)
}
diff --git a/log.go b/log.go
index 5afba8ed..46ec57dd 100644
--- a/log.go
+++ b/log.go
@@ -32,20 +32,20 @@ const (
LevelDebug
)
-// SetLogLevel sets the global log level used by the simple
-// logger.
+// BeeLogger references the used application logger.
+var BeeLogger = logs.NewLogger(100)
+
+// SetLevel sets the global log level used by the simple logger.
func SetLevel(l int) {
BeeLogger.SetLevel(l)
}
+// SetLogFuncCall set the CallDepth, default is 3
func SetLogFuncCall(b bool) {
BeeLogger.EnableFuncCallDepth(b)
BeeLogger.SetLogFuncCallDepth(3)
}
-// logger references the used application logger.
-var BeeLogger *logs.BeeLogger
-
// SetLogger sets a new logger.
func SetLogger(adaptername string, config string) error {
err := BeeLogger.SetLogger(adaptername, config)
@@ -55,10 +55,12 @@ func SetLogger(adaptername string, config string) error {
return nil
}
+// Emergency logs a message at emergency level.
func Emergency(v ...interface{}) {
BeeLogger.Emergency(generateFmtStr(len(v)), v...)
}
+// Alert logs a message at alert level.
func Alert(v ...interface{}) {
BeeLogger.Alert(generateFmtStr(len(v)), v...)
}
@@ -78,23 +80,24 @@ func Warning(v ...interface{}) {
BeeLogger.Warning(generateFmtStr(len(v)), v...)
}
-// Deprecated: compatibility alias for Warning(), Will be removed in 1.5.0.
+// Warn compatibility alias for Warning()
func Warn(v ...interface{}) {
- Warning(v...)
+ BeeLogger.Warn(generateFmtStr(len(v)), v...)
}
+// Notice logs a message at notice level.
func Notice(v ...interface{}) {
BeeLogger.Notice(generateFmtStr(len(v)), v...)
}
-// Info logs a message at info level.
+// Informational logs a message at info level.
func Informational(v ...interface{}) {
BeeLogger.Informational(generateFmtStr(len(v)), v...)
}
-// Deprecated: compatibility alias for Warning(), Will be removed in 1.5.0.
+// Info compatibility alias for Warning()
func Info(v ...interface{}) {
- Informational(v...)
+ BeeLogger.Info(generateFmtStr(len(v)), v...)
}
// Debug logs a message at debug level.
@@ -103,7 +106,7 @@ func Debug(v ...interface{}) {
}
// Trace logs a message at trace level.
-// Deprecated: compatibility alias for Warning(), Will be removed in 1.5.0.
+// compatibility alias for Warning()
func Trace(v ...interface{}) {
BeeLogger.Trace(generateFmtStr(len(v)), v...)
}
diff --git a/logs/conn.go b/logs/conn.go
index 2240eece..3655bf51 100644
--- a/logs/conn.go
+++ b/logs/conn.go
@@ -21,9 +21,9 @@ import (
"net"
)
-// ConnWriter implements LoggerInterface.
+// connWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection.
-type ConnWriter struct {
+type connWriter struct {
lg *log.Logger
innerWriter io.WriteCloser
ReconnectOnMsg bool `json:"reconnectOnMsg"`
@@ -33,22 +33,22 @@ type ConnWriter struct {
Level int `json:"level"`
}
-// create new ConnWrite returning as LoggerInterface.
-func NewConn() LoggerInterface {
- conn := new(ConnWriter)
+// NewConn create new ConnWrite returning as LoggerInterface.
+func NewConn() Logger {
+ conn := new(connWriter)
conn.Level = LevelTrace
return conn
}
-// init connection writer with json config.
+// Init init connection writer with json config.
// json config only need key "level".
-func (c *ConnWriter) Init(jsonconfig string) error {
+func (c *connWriter) Init(jsonconfig string) error {
return json.Unmarshal([]byte(jsonconfig), c)
}
-// write message in connection.
+// WriteMsg write message in connection.
// if connection is down, try to re-connect.
-func (c *ConnWriter) WriteMsg(msg string, level int) error {
+func (c *connWriter) WriteMsg(msg string, level int) error {
if level > c.Level {
return nil
}
@@ -66,19 +66,19 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
return nil
}
-// implementing method. empty.
-func (c *ConnWriter) Flush() {
+// Flush implementing method. empty.
+func (c *connWriter) Flush() {
}
-// destroy connection writer and close tcp listener.
-func (c *ConnWriter) Destroy() {
+// Destroy destroy connection writer and close tcp listener.
+func (c *connWriter) Destroy() {
if c.innerWriter != nil {
c.innerWriter.Close()
}
}
-func (c *ConnWriter) connect() error {
+func (c *connWriter) connect() error {
if c.innerWriter != nil {
c.innerWriter.Close()
c.innerWriter = nil
@@ -98,7 +98,7 @@ func (c *ConnWriter) connect() error {
return nil
}
-func (c *ConnWriter) neddedConnectOnMsg() bool {
+func (c *connWriter) neddedConnectOnMsg() bool {
if c.Reconnect {
c.Reconnect = false
return true
diff --git a/logs/console.go b/logs/console.go
index ce7ecd54..23e8ebca 100644
--- a/logs/console.go
+++ b/logs/console.go
@@ -21,9 +21,11 @@ import (
"runtime"
)
-type Brush func(string) string
+// brush is a color join function
+type brush func(string) string
-func NewBrush(color string) Brush {
+// newBrush return a fix color Brush
+func newBrush(color string) brush {
pre := "\033["
reset := "\033[0m"
return func(text string) string {
@@ -31,43 +33,43 @@ func NewBrush(color string) Brush {
}
}
-var colors = []Brush{
- NewBrush("1;37"), // Emergency white
- NewBrush("1;36"), // Alert cyan
- NewBrush("1;35"), // Critical magenta
- NewBrush("1;31"), // Error red
- NewBrush("1;33"), // Warning yellow
- NewBrush("1;32"), // Notice green
- NewBrush("1;34"), // Informational blue
- NewBrush("1;34"), // Debug blue
+var colors = []brush{
+ newBrush("1;37"), // Emergency white
+ newBrush("1;36"), // Alert cyan
+ newBrush("1;35"), // Critical magenta
+ newBrush("1;31"), // Error red
+ newBrush("1;33"), // Warning yellow
+ newBrush("1;32"), // Notice green
+ newBrush("1;34"), // Informational blue
+ newBrush("1;34"), // Debug blue
}
-// ConsoleWriter implements LoggerInterface and writes messages to terminal.
-type ConsoleWriter struct {
+// consoleWriter implements LoggerInterface and writes messages to terminal.
+type consoleWriter struct {
lg *log.Logger
Level int `json:"level"`
}
-// create ConsoleWriter returning as LoggerInterface.
-func NewConsole() LoggerInterface {
- cw := &ConsoleWriter{
+// NewConsole create ConsoleWriter returning as LoggerInterface.
+func NewConsole() Logger {
+ cw := &consoleWriter{
lg: log.New(os.Stdout, "", log.Ldate|log.Ltime),
Level: LevelDebug,
}
return cw
}
-// init console logger.
+// Init init console logger.
// jsonconfig like '{"level":LevelTrace}'.
-func (c *ConsoleWriter) Init(jsonconfig string) error {
+func (c *consoleWriter) Init(jsonconfig string) error {
if len(jsonconfig) == 0 {
return nil
}
return json.Unmarshal([]byte(jsonconfig), c)
}
-// write message in console.
-func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
+// WriteMsg write message in console.
+func (c *consoleWriter) WriteMsg(msg string, level int) error {
if level > c.Level {
return nil
}
@@ -80,13 +82,13 @@ func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
return nil
}
-// implementing method. empty.
-func (c *ConsoleWriter) Destroy() {
+// Destroy implementing method. empty.
+func (c *consoleWriter) Destroy() {
}
-// implementing method. empty.
-func (c *ConsoleWriter) Flush() {
+// Flush implementing method. empty.
+func (c *consoleWriter) Flush() {
}
diff --git a/logs/console_test.go b/logs/console_test.go
index 2fad7241..ce8937d4 100644
--- a/logs/console_test.go
+++ b/logs/console_test.go
@@ -42,12 +42,3 @@ func TestConsole(t *testing.T) {
log2.SetLogger("console", `{"level":3}`)
testConsoleCalls(log2)
}
-
-func BenchmarkConsole(b *testing.B) {
- log := NewLogger(10000)
- log.EnableFuncCallDepth(true)
- log.SetLogger("console", "")
- for i := 0; i < b.N; i++ {
- log.Debug("debug")
- }
-}
diff --git a/logs/es/es.go b/logs/es/es.go
new file mode 100644
index 00000000..f8dc5f65
--- /dev/null
+++ b/logs/es/es.go
@@ -0,0 +1,80 @@
+package es
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net"
+ "net/url"
+ "time"
+
+ "github.com/astaxie/beego/logs"
+ "github.com/belogik/goes"
+)
+
+// NewES return a LoggerInterface
+func NewES() logs.Logger {
+ cw := &esLogger{
+ Level: logs.LevelDebug,
+ }
+ return cw
+}
+
+type esLogger struct {
+ *goes.Connection
+ DSN string `json:"dsn"`
+ Level int `json:"level"`
+}
+
+// {"dsn":"http://localhost:9200/","level":1}
+func (el *esLogger) Init(jsonconfig string) error {
+ err := json.Unmarshal([]byte(jsonconfig), el)
+ if err != nil {
+ return err
+ }
+ if el.DSN == "" {
+ return errors.New("empty dsn")
+ } else if u, err := url.Parse(el.DSN); err != nil {
+ return err
+ } else if u.Path == "" {
+ return errors.New("missing prefix")
+ } else if host, port, err := net.SplitHostPort(u.Host); err != nil {
+ return err
+ } else {
+ conn := goes.NewConnection(host, port)
+ el.Connection = conn
+ }
+ return nil
+}
+
+// WriteMsg will write the msg and level into es
+func (el *esLogger) WriteMsg(msg string, level int) error {
+ if level > el.Level {
+ return nil
+ }
+ t := time.Now()
+ vals := make(map[string]interface{})
+ vals["@timestamp"] = t.Format(time.RFC3339)
+ vals["@msg"] = msg
+ d := goes.Document{
+ Index: fmt.Sprintf("%04d.%02d.%02d", t.Year(), t.Month(), t.Day()),
+ Type: "logs",
+ Fields: vals,
+ }
+ _, err := el.Index(d, nil)
+ return err
+}
+
+// Destroy is a empty method
+func (el *esLogger) Destroy() {
+
+}
+
+// Flush is a empty method
+func (el *esLogger) Flush() {
+
+}
+
+func init() {
+ logs.Register("es", NewES)
+}
diff --git a/logs/file.go b/logs/file.go
index 2d3449ce..0eae734a 100644
--- a/logs/file.go
+++ b/logs/file.go
@@ -20,7 +20,6 @@ import (
"errors"
"fmt"
"io"
- "log"
"os"
"path/filepath"
"strings"
@@ -28,84 +27,62 @@ import (
"time"
)
-// FileLogWriter implements LoggerInterface.
+// fileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency.
-type FileLogWriter struct {
- *log.Logger
- mw *MuxWriter
+type fileLogWriter struct {
+ sync.Mutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize
// The opened file
- Filename string `json:"filename"`
+ Filename string `json:"filename"`
+ fileWriter *os.File
- Maxlines int `json:"maxlines"`
- maxlines_curlines int
+ // Rotate at line
+ MaxLines int `json:"maxlines"`
+ maxLinesCurLines int
// Rotate at size
- Maxsize int `json:"maxsize"`
- maxsize_cursize int
+ MaxSize int `json:"maxsize"`
+ maxSizeCurSize int
// Rotate daily
- Daily bool `json:"daily"`
- Maxdays int64 `json:"maxdays"`
- daily_opendate int
+ Daily bool `json:"daily"`
+ MaxDays int64 `json:"maxdays"`
+ dailyOpenDate int
Rotate bool `json:"rotate"`
- startLock sync.Mutex // Only one log can write to the file
-
Level int `json:"level"`
+
+ Perm os.FileMode `json:"perm"`
}
-// an *os.File writer with locker.
-type MuxWriter struct {
- sync.Mutex
- fd *os.File
-}
-
-// write to os.File.
-func (l *MuxWriter) Write(b []byte) (int, error) {
- l.Lock()
- defer l.Unlock()
- return l.fd.Write(b)
-}
-
-// set os.File in writer.
-func (l *MuxWriter) SetFd(fd *os.File) {
- if l.fd != nil {
- l.fd.Close()
- }
- l.fd = fd
-}
-
-// create a FileLogWriter returning as LoggerInterface.
-func NewFileWriter() LoggerInterface {
- w := &FileLogWriter{
+// NewFileWriter create a FileLogWriter returning as LoggerInterface.
+func newFileWriter() Logger {
+ w := &fileLogWriter{
Filename: "",
- Maxlines: 1000000,
- Maxsize: 1 << 28, //256 MB
+ MaxLines: 1000000,
+ MaxSize: 1 << 28, //256 MB
Daily: true,
- Maxdays: 7,
+ MaxDays: 7,
Rotate: true,
Level: LevelTrace,
+ Perm: 0660,
}
- // use MuxWriter instead direct use os.File for lock write when rotate
- w.mw = new(MuxWriter)
- // set MuxWriter as Logger's io.Writer
- w.Logger = log.New(w.mw, "", log.Ldate|log.Ltime)
return w
}
// Init file logger with json config.
-// jsonconfig like:
+// jsonConfig like:
// {
// "filename":"logs/beego.log",
-// "maxlines":10000,
+// "maxLines":10000,
// "maxsize":1<<30,
// "daily":true,
-// "maxdays":15,
-// "rotate":true
+// "maxDays":15,
+// "rotate":true,
+// "perm":0600
// }
-func (w *FileLogWriter) Init(jsonconfig string) error {
- err := json.Unmarshal([]byte(jsonconfig), w)
+func (w *fileLogWriter) Init(jsonConfig string) error {
+ err := json.Unmarshal([]byte(jsonConfig), w)
if err != nil {
return err
}
@@ -117,67 +94,120 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
}
// start file logger. create log file and set to locker-inside file writer.
-func (w *FileLogWriter) startLogger() error {
- fd, err := w.createLogFile()
+func (w *fileLogWriter) startLogger() error {
+ file, err := w.createLogFile()
if err != nil {
return err
}
- w.mw.SetFd(fd)
+ if w.fileWriter != nil {
+ w.fileWriter.Close()
+ }
+ w.fileWriter = file
return w.initFd()
}
-func (w *FileLogWriter) docheck(size int) {
- w.startLock.Lock()
- defer w.startLock.Unlock()
- if w.Rotate && ((w.Maxlines > 0 && w.maxlines_curlines >= w.Maxlines) ||
- (w.Maxsize > 0 && w.maxsize_cursize >= w.Maxsize) ||
- (w.Daily && time.Now().Day() != w.daily_opendate)) {
- if err := w.DoRotate(); err != nil {
- fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
- return
- }
- }
- w.maxlines_curlines++
- w.maxsize_cursize += size
+func (w *fileLogWriter) needRotate(size int, day int) bool {
+ return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
+ (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
+ (w.Daily && day != w.dailyOpenDate)
+
}
-// write logger message into file.
-func (w *FileLogWriter) WriteMsg(msg string, level int) error {
+// WriteMsg write logger message into file.
+func (w *fileLogWriter) WriteMsg(msg string, level int) error {
if level > w.Level {
return nil
}
- n := 24 + len(msg) // 24 stand for the length "2013/06/23 21:00:22 [T] "
- w.docheck(n)
- w.Logger.Println(msg)
- return nil
+ //2016/01/12 21:34:33
+ now := time.Now()
+ y, mo, d := now.Date()
+ h, mi, s := now.Clock()
+ //len(2006/01/02 15:03:04)==19
+ var buf [20]byte
+ t := 3
+ for y >= 10 {
+ p := y / 10
+ buf[t] = byte('0' + y - p*10)
+ y = p
+ t--
+ }
+ buf[0] = byte('0' + y)
+ buf[4] = '/'
+ if mo > 9 {
+ buf[5] = '1'
+ buf[6] = byte('0' + mo - 9)
+ } else {
+ buf[5] = '0'
+ buf[6] = byte('0' + mo)
+ }
+ buf[7] = '/'
+ t = d / 10
+ buf[8] = byte('0' + t)
+ buf[9] = byte('0' + d - t*10)
+ buf[10] = ' '
+ t = h / 10
+ buf[11] = byte('0' + t)
+ buf[12] = byte('0' + h - t*10)
+ buf[13] = ':'
+ t = mi / 10
+ buf[14] = byte('0' + t)
+ buf[15] = byte('0' + mi - t*10)
+ buf[16] = ':'
+ t = s / 10
+ buf[17] = byte('0' + t)
+ buf[18] = byte('0' + s - t*10)
+ buf[19] = ' '
+ msg = string(buf[0:]) + msg + "\n"
+
+ if w.Rotate {
+ if w.needRotate(len(msg), d) {
+ w.Lock()
+ if w.needRotate(len(msg), d) {
+ if err := w.doRotate(); err != nil {
+ fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
+ }
+
+ }
+ w.Unlock()
+ }
+ }
+
+ w.Lock()
+ _, err := w.fileWriter.Write([]byte(msg))
+ if err == nil {
+ w.maxLinesCurLines++
+ w.maxSizeCurSize += len(msg)
+ }
+ w.Unlock()
+ return err
}
-func (w *FileLogWriter) createLogFile() (*os.File, error) {
+func (w *fileLogWriter) createLogFile() (*os.File, error) {
// Open the log file
- fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660)
+ fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, w.Perm)
return fd, err
}
-func (w *FileLogWriter) initFd() error {
- fd := w.mw.fd
- finfo, err := fd.Stat()
+func (w *fileLogWriter) initFd() error {
+ fd := w.fileWriter
+ fInfo, err := fd.Stat()
if err != nil {
return fmt.Errorf("get stat err: %s\n", err)
}
- w.maxsize_cursize = int(finfo.Size())
- w.daily_opendate = time.Now().Day()
- w.maxlines_curlines = 0
- if finfo.Size() > 0 {
+ w.maxSizeCurSize = int(fInfo.Size())
+ w.dailyOpenDate = time.Now().Day()
+ w.maxLinesCurLines = 0
+ if fInfo.Size() > 0 {
count, err := w.lines()
if err != nil {
return err
}
- w.maxlines_curlines = count
+ w.maxLinesCurLines = count
}
return nil
}
-func (w *FileLogWriter) lines() (int, error) {
+func (w *fileLogWriter) lines() (int, error) {
fd, err := os.Open(w.Filename)
if err != nil {
return 0, err
@@ -205,59 +235,60 @@ func (w *FileLogWriter) lines() (int, error) {
}
// DoRotate means it need to write file in new file.
-// new file name like xx.log.2013-01-01.2
-func (w *FileLogWriter) DoRotate() error {
+// new file name like xx.2013-01-01.2.log
+func (w *fileLogWriter) doRotate() error {
_, err := os.Lstat(w.Filename)
- if err == nil { // file exists
- // Find the next available number
- num := 1
- fname := ""
- for ; err == nil && num <= 999; num++ {
- fname = w.Filename + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), num)
- _, err = os.Lstat(fname)
- }
- // return error if the last file checked still existed
- if err == nil {
- return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename)
- }
-
- // block Logger's io.Writer
- w.mw.Lock()
- defer w.mw.Unlock()
-
- fd := w.mw.fd
- fd.Close()
-
- // close fd before rename
- // Rename the file to its newfound home
- err = os.Rename(w.Filename, fname)
- if err != nil {
- return fmt.Errorf("Rotate: %s\n", err)
- }
-
- // re-start logger
- err = w.startLogger()
- if err != nil {
- return fmt.Errorf("Rotate StartLogger: %s\n", err)
- }
-
- go w.deleteOldLog()
+ if err != nil {
+ return err
+ }
+ // file exists
+ // Find the next available number
+ num := 1
+ fName := ""
+ suffix := filepath.Ext(w.Filename)
+ filenameOnly := strings.TrimSuffix(w.Filename, suffix)
+ if suffix == "" {
+ suffix = ".log"
+ }
+ for ; err == nil && num <= 999; num++ {
+ fName = filenameOnly + fmt.Sprintf(".%s.%03d%s", time.Now().Format("2006-01-02"), num, suffix)
+ _, err = os.Lstat(fName)
+ }
+ // return error if the last file checked still existed
+ if err == nil {
+ return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename)
}
+ // close fileWriter before rename
+ w.fileWriter.Close()
+
+ // Rename the file to its new found name
+ // even if occurs error,we MUST guarantee to restart new logger
+ renameErr := os.Rename(w.Filename, fName)
+ // re-start logger
+ startLoggerErr := w.startLogger()
+ go w.deleteOldLog()
+
+ if startLoggerErr != nil {
+ return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr)
+ }
+ if renameErr != nil {
+ return fmt.Errorf("Rotate: %s\n", renameErr)
+ }
return nil
+
}
-func (w *FileLogWriter) deleteOldLog() {
+func (w *fileLogWriter) deleteOldLog() {
dir := filepath.Dir(w.Filename)
filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) {
defer func() {
if r := recover(); r != nil {
- returnErr = fmt.Errorf("Unable to delete old log '%s', error: %+v", path, r)
- fmt.Println(returnErr)
+ fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r)
}
}()
- if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.Maxdays) {
+ if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.MaxDays) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.Filename)) {
os.Remove(path)
}
@@ -266,18 +297,18 @@ func (w *FileLogWriter) deleteOldLog() {
})
}
-// destroy file logger, close file writer.
-func (w *FileLogWriter) Destroy() {
- w.mw.fd.Close()
+// Destroy close the file description, close file writer.
+func (w *fileLogWriter) Destroy() {
+ w.fileWriter.Close()
}
-// flush file logger.
+// Flush flush file logger.
// there are no buffering messages in file logger in memory.
// flush file means sync file from disk.
-func (w *FileLogWriter) Flush() {
- w.mw.fd.Sync()
+func (w *fileLogWriter) Flush() {
+ w.fileWriter.Sync()
}
func init() {
- Register("file", NewFileWriter)
+ Register("file", newFileWriter)
}
diff --git a/logs/file_test.go b/logs/file_test.go
index c71e9bb4..1fa6cdaa 100644
--- a/logs/file_test.go
+++ b/logs/file_test.go
@@ -23,7 +23,7 @@ import (
"time"
)
-func TestFile(t *testing.T) {
+func TestFile1(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`)
log.Debug("debug")
@@ -34,25 +34,24 @@ func TestFile(t *testing.T) {
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
- time.Sleep(time.Second * 4)
f, err := os.Open("test.log")
if err != nil {
t.Fatal(err)
}
b := bufio.NewReader(f)
- linenum := 0
+ lineNum := 0
for {
line, _, err := b.ReadLine()
if err != nil {
break
}
if len(line) > 0 {
- linenum++
+ lineNum++
}
}
var expected = LevelDebug + 1
- if linenum != expected {
- t.Fatal(linenum, "not "+strconv.Itoa(expected)+" lines")
+ if lineNum != expected {
+ t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
}
os.Remove("test.log")
}
@@ -68,25 +67,24 @@ func TestFile2(t *testing.T) {
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
- time.Sleep(time.Second * 4)
f, err := os.Open("test2.log")
if err != nil {
t.Fatal(err)
}
b := bufio.NewReader(f)
- linenum := 0
+ lineNum := 0
for {
line, _, err := b.ReadLine()
if err != nil {
break
}
if len(line) > 0 {
- linenum++
+ lineNum++
}
}
var expected = LevelError + 1
- if linenum != expected {
- t.Fatal(linenum, "not "+strconv.Itoa(expected)+" lines")
+ if lineNum != expected {
+ t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
}
os.Remove("test2.log")
}
@@ -102,13 +100,13 @@ func TestFileRotate(t *testing.T) {
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
- time.Sleep(time.Second * 4)
- rotatename := "test3.log" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1)
- b, err := exists(rotatename)
+ rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
+ b, err := exists(rotateName)
if !b || err != nil {
+ os.Remove("test3.log")
t.Fatal("rotate not generated")
}
- os.Remove(rotatename)
+ os.Remove(rotateName)
os.Remove("test3.log")
}
@@ -131,3 +129,45 @@ func BenchmarkFile(b *testing.B) {
}
os.Remove("test4.log")
}
+
+func BenchmarkFileAsynchronous(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ log.Async()
+ for i := 0; i < b.N; i++ {
+ log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
+
+func BenchmarkFileCallDepth(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ log.EnableFuncCallDepth(true)
+ log.SetLogFuncCallDepth(2)
+ for i := 0; i < b.N; i++ {
+ log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
+
+func BenchmarkFileAsynchronousCallDepth(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ log.EnableFuncCallDepth(true)
+ log.SetLogFuncCallDepth(2)
+ log.Async()
+ for i := 0; i < b.N; i++ {
+ log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
+
+func BenchmarkFileOnGoroutine(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ for i := 0; i < b.N; i++ {
+ go log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
diff --git a/logs/log.go b/logs/log.go
index 32e0187c..ccaaa3ad 100644
--- a/logs/log.go
+++ b/logs/log.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package logs provide a general log interface
// Usage:
//
// import "github.com/astaxie/beego/logs"
@@ -34,8 +35,10 @@ package logs
import (
"fmt"
+ "os"
"path"
"runtime"
+ "strconv"
"sync"
)
@@ -60,10 +63,10 @@ const (
LevelWarn = LevelWarning
)
-type loggerType func() LoggerInterface
+type loggerType func() Logger
-// LoggerInterface defines the behavior of a log provider.
-type LoggerInterface interface {
+// Logger defines the behavior of a log provider.
+type Logger interface {
Init(config string) error
WriteMsg(msg string, level int) error
Destroy()
@@ -92,8 +95,14 @@ type BeeLogger struct {
level int
enableFuncCallDepth bool
loggerFuncCallDepth int
- msg chan *logMsg
- outputs map[string]LoggerInterface
+ asynchronous bool
+ msgChan chan *logMsg
+ outputs []*nameLogger
+}
+
+type nameLogger struct {
+ Logger
+ name string
}
type logMsg struct {
@@ -101,90 +110,117 @@ type logMsg struct {
msg string
}
+var logMsgPool *sync.Pool
+
// NewLogger returns a new BeeLogger.
-// channellen means the number of messages in chan.
+// channelLen means the number of messages in chan(used where asynchronous is true).
// if the buffering chan is full, logger adapters write to file or other way.
-func NewLogger(channellen int64) *BeeLogger {
+func NewLogger(channelLen int64) *BeeLogger {
bl := new(BeeLogger)
bl.level = LevelDebug
bl.loggerFuncCallDepth = 2
- bl.msg = make(chan *logMsg, channellen)
- bl.outputs = make(map[string]LoggerInterface)
- //bl.SetLogger("console", "") // default output to console
+ bl.msgChan = make(chan *logMsg, channelLen)
+ return bl
+}
+
+// Async set the log to asynchronous and start the goroutine
+func (bl *BeeLogger) Async() *BeeLogger {
+ bl.asynchronous = true
+ logMsgPool = &sync.Pool{
+ New: func() interface{} {
+ return &logMsg{}
+ },
+ }
go bl.startLogger()
return bl
}
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
-func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
+func (bl *BeeLogger) SetLogger(adapterName string, config string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
- if log, ok := adapters[adaptername]; ok {
+ if log, ok := adapters[adapterName]; ok {
lg := log()
err := lg.Init(config)
- bl.outputs[adaptername] = lg
if err != nil {
- fmt.Println("logs.BeeLogger.SetLogger: " + err.Error())
+ fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error())
return err
}
+ bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg})
} else {
- return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername)
+ return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
}
return nil
}
-// remove a logger adapter in BeeLogger.
-func (bl *BeeLogger) DelLogger(adaptername string) error {
+// DelLogger remove a logger adapter in BeeLogger.
+func (bl *BeeLogger) DelLogger(adapterName string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
- if lg, ok := bl.outputs[adaptername]; ok {
- lg.Destroy()
- delete(bl.outputs, adaptername)
- return nil
- } else {
- return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername)
- }
-}
-
-func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
- if loglevel > bl.level {
- return nil
- }
- lm := new(logMsg)
- lm.level = loglevel
- if bl.enableFuncCallDepth {
- _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
- if _, filename := path.Split(file); filename == "log.go" && (line == 97 || line == 83) {
- _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth + 1)
- }
- if ok {
- _, filename := path.Split(file)
- lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg)
+ outputs := []*nameLogger{}
+ for _, lg := range bl.outputs {
+ if lg.name == adapterName {
+ lg.Destroy()
} else {
- lm.msg = msg
+ outputs = append(outputs, lg)
}
- } else {
- lm.msg = msg
}
- bl.msg <- lm
+ if len(outputs) == len(bl.outputs) {
+ return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
+ }
+ bl.outputs = outputs
return nil
}
-// Set log message level.
-//
+func (bl *BeeLogger) writeToLoggers(msg string, level int) {
+ for _, l := range bl.outputs {
+ err := l.WriteMsg(msg, level)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err)
+ }
+ }
+}
+
+func (bl *BeeLogger) writeMsg(logLevel int, msg string) error {
+ if bl.enableFuncCallDepth {
+ _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
+ if !ok {
+ file = "???"
+ line = 0
+ }
+ _, filename := path.Split(file)
+ msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "]" + msg
+ }
+ if bl.asynchronous {
+ lm := logMsgPool.Get().(*logMsg)
+ lm.level = logLevel
+ lm.msg = msg
+ bl.msgChan <- lm
+ } else {
+ bl.writeToLoggers(msg, logLevel)
+ }
+ return nil
+}
+
+// SetLevel Set log message level.
// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning),
// log providers will not even be sent the message.
func (bl *BeeLogger) SetLevel(l int) {
bl.level = l
}
-// set log funcCallDepth
+// SetLogFuncCallDepth set log funcCallDepth
func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
bl.loggerFuncCallDepth = d
}
-// enable log funcCallDepth
+// GetLogFuncCallDepth return log funcCallDepth for wrapper
+func (bl *BeeLogger) GetLogFuncCallDepth() int {
+ return bl.loggerFuncCallDepth
+}
+
+// EnableFuncCallDepth enable log funcCallDepth
func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
bl.enableFuncCallDepth = b
}
@@ -194,104 +230,129 @@ func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
func (bl *BeeLogger) startLogger() {
for {
select {
- case bm := <-bl.msg:
- for _, l := range bl.outputs {
- err := l.WriteMsg(bm.msg, bm.level)
- if err != nil {
- fmt.Println("ERROR, unable to WriteMsg:", err)
- }
- }
+ case bm := <-bl.msgChan:
+ bl.writeToLoggers(bm.msg, bm.level)
+ logMsgPool.Put(bm)
}
}
}
-// Log EMERGENCY level message.
+// Emergency Log EMERGENCY level message.
func (bl *BeeLogger) Emergency(format string, v ...interface{}) {
+ if LevelEmergency > bl.level {
+ return
+ }
msg := fmt.Sprintf("[M] "+format, v...)
- bl.writerMsg(LevelEmergency, msg)
+ bl.writeMsg(LevelEmergency, msg)
}
-// Log ALERT level message.
+// Alert Log ALERT level message.
func (bl *BeeLogger) Alert(format string, v ...interface{}) {
+ if LevelAlert > bl.level {
+ return
+ }
msg := fmt.Sprintf("[A] "+format, v...)
- bl.writerMsg(LevelAlert, msg)
+ bl.writeMsg(LevelAlert, msg)
}
-// Log CRITICAL level message.
+// Critical Log CRITICAL level message.
func (bl *BeeLogger) Critical(format string, v ...interface{}) {
+ if LevelCritical > bl.level {
+ return
+ }
msg := fmt.Sprintf("[C] "+format, v...)
- bl.writerMsg(LevelCritical, msg)
+ bl.writeMsg(LevelCritical, msg)
}
-// Log ERROR level message.
+// Error Log ERROR level message.
func (bl *BeeLogger) Error(format string, v ...interface{}) {
+ if LevelError > bl.level {
+ return
+ }
msg := fmt.Sprintf("[E] "+format, v...)
- bl.writerMsg(LevelError, msg)
+ bl.writeMsg(LevelError, msg)
}
-// Log WARNING level message.
+// Warning Log WARNING level message.
func (bl *BeeLogger) Warning(format string, v ...interface{}) {
+ if LevelWarning > bl.level {
+ return
+ }
msg := fmt.Sprintf("[W] "+format, v...)
- bl.writerMsg(LevelWarning, msg)
+ bl.writeMsg(LevelWarning, msg)
}
-// Log NOTICE level message.
+// Notice Log NOTICE level message.
func (bl *BeeLogger) Notice(format string, v ...interface{}) {
+ if LevelNotice > bl.level {
+ return
+ }
msg := fmt.Sprintf("[N] "+format, v...)
- bl.writerMsg(LevelNotice, msg)
+ bl.writeMsg(LevelNotice, msg)
}
-// Log INFORMATIONAL level message.
+// Informational Log INFORMATIONAL level message.
func (bl *BeeLogger) Informational(format string, v ...interface{}) {
+ if LevelInformational > bl.level {
+ return
+ }
msg := fmt.Sprintf("[I] "+format, v...)
- bl.writerMsg(LevelInformational, msg)
+ bl.writeMsg(LevelInformational, msg)
}
-// Log DEBUG level message.
+// Debug Log DEBUG level message.
func (bl *BeeLogger) Debug(format string, v ...interface{}) {
+ if LevelDebug > bl.level {
+ return
+ }
msg := fmt.Sprintf("[D] "+format, v...)
- bl.writerMsg(LevelDebug, msg)
+ bl.writeMsg(LevelDebug, msg)
}
-// Log WARN level message.
-//
-// Deprecated: compatibility alias for Warning(), Will be removed in 1.5.0.
+// Warn Log WARN level message.
+// compatibility alias for Warning()
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
- bl.Warning(format, v...)
+ if LevelWarning > bl.level {
+ return
+ }
+ msg := fmt.Sprintf("[W] "+format, v...)
+ bl.writeMsg(LevelWarning, msg)
}
-// Log INFO level message.
-//
-// Deprecated: compatibility alias for Informational(), Will be removed in 1.5.0.
+// Info Log INFO level message.
+// compatibility alias for Informational()
func (bl *BeeLogger) Info(format string, v ...interface{}) {
- bl.Informational(format, v...)
+ if LevelInformational > bl.level {
+ return
+ }
+ msg := fmt.Sprintf("[I] "+format, v...)
+ bl.writeMsg(LevelInformational, msg)
}
-// Log TRACE level message.
-//
-// Deprecated: compatibility alias for Debug(), Will be removed in 1.5.0.
+// Trace Log TRACE level message.
+// compatibility alias for Debug()
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
- bl.Debug(format, v...)
+ if LevelDebug > bl.level {
+ return
+ }
+ msg := fmt.Sprintf("[D] "+format, v...)
+ bl.writeMsg(LevelDebug, msg)
}
-// flush all chan data.
+// Flush flush all chan data.
func (bl *BeeLogger) Flush() {
for _, l := range bl.outputs {
l.Flush()
}
}
-// close logger, flush all chan data and destroy all adapters in BeeLogger.
+// Close close logger, flush all chan data and destroy all adapters in BeeLogger.
func (bl *BeeLogger) Close() {
for {
- if len(bl.msg) > 0 {
- bm := <-bl.msg
- for _, l := range bl.outputs {
- err := l.WriteMsg(bm.msg, bm.level)
- if err != nil {
- fmt.Println("ERROR, unable to WriteMsg (while closing logger):", err)
- }
- }
+ if len(bl.msgChan) > 0 {
+ bm := <-bl.msgChan
+ bl.writeToLoggers(bm.msg, bm.level)
+ logMsgPool.Put(bm)
continue
}
break
diff --git a/logs/smtp.go b/logs/smtp.go
index 95123ebf..748462f9 100644
--- a/logs/smtp.go
+++ b/logs/smtp.go
@@ -24,31 +24,26 @@ import (
"time"
)
-const (
-// no usage
-// subjectPhrase = "Diagnostic message from server"
-)
-
-// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
-type SmtpWriter struct {
- Username string `json:"Username"`
+// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server.
+type SMTPWriter struct {
+ Username string `json:"username"`
Password string `json:"password"`
- Host string `json:"Host"`
+ Host string `json:"host"`
Subject string `json:"subject"`
FromAddress string `json:"fromAddress"`
RecipientAddresses []string `json:"sendTos"`
Level int `json:"level"`
}
-// create smtp writer.
-func NewSmtpWriter() LoggerInterface {
- return &SmtpWriter{Level: LevelTrace}
+// NewSMTPWriter create smtp writer.
+func newSMTPWriter() Logger {
+ return &SMTPWriter{Level: LevelTrace}
}
-// init smtp writer with json config.
+// Init smtp writer with json config.
// config like:
// {
-// "Username":"example@gmail.com",
+// "username":"example@gmail.com",
// "password:"password",
// "host":"smtp.gmail.com:465",
// "subject":"email title",
@@ -56,7 +51,7 @@ func NewSmtpWriter() LoggerInterface {
// "sendTos":["email1","email2"],
// "level":LevelError
// }
-func (s *SmtpWriter) Init(jsonconfig string) error {
+func (s *SMTPWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s)
if err != nil {
return err
@@ -64,7 +59,7 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
return nil
}
-func (s *SmtpWriter) GetSmtpAuth(host string) smtp.Auth {
+func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth {
if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 {
return nil
}
@@ -76,7 +71,7 @@ func (s *SmtpWriter) GetSmtpAuth(host string) smtp.Auth {
)
}
-func (s *SmtpWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error {
+func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error {
client, err := smtp.Dial(hostAddressWithPort)
if err != nil {
return err
@@ -129,9 +124,9 @@ func (s *SmtpWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd
return nil
}
-// write message in smtp writer.
+// WriteMsg write message in smtp writer.
// it will send an email with subject and only this message.
-func (s *SmtpWriter) WriteMsg(msg string, level int) error {
+func (s *SMTPWriter) WriteMsg(msg string, level int) error {
if level > s.Level {
return nil
}
@@ -139,27 +134,27 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
hp := strings.Split(s.Host, ":")
// Set up authentication information.
- auth := s.GetSmtpAuth(hp[0])
+ auth := s.getSMTPAuth(hp[0])
// Connect to the server, authenticate, set the sender and recipient,
// and send the email all in one step.
- content_type := "Content-Type: text/plain" + "; charset=UTF-8"
+ contentType := "Content-Type: text/plain" + "; charset=UTF-8"
mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress +
- ">\r\nSubject: " + s.Subject + "\r\n" + content_type + "\r\n\r\n" + fmt.Sprintf(".%s", time.Now().Format("2006-01-02 15:04:05")) + msg)
+ ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", time.Now().Format("2006-01-02 15:04:05")) + msg)
return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg)
}
-// implementing method. empty.
-func (s *SmtpWriter) Flush() {
+// Flush implementing method. empty.
+func (s *SMTPWriter) Flush() {
return
}
-// implementing method. empty.
-func (s *SmtpWriter) Destroy() {
+// Destroy implementing method. empty.
+func (s *SMTPWriter) Destroy() {
return
}
func init() {
- Register("smtp", NewSmtpWriter)
+ Register("smtp", newSMTPWriter)
}
diff --git a/memzipfile.go b/memzipfile.go
deleted file mode 100644
index cc5e3851..00000000
--- a/memzipfile.go
+++ /dev/null
@@ -1,212 +0,0 @@
-// Copyright 2014 beego Author. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package beego
-
-import (
- "bytes"
- "compress/flate"
- "compress/gzip"
- "errors"
- "io"
- "io/ioutil"
- "net/http"
- "os"
- "strings"
- "sync"
- "time"
-)
-
-var gmfim map[string]*memFileInfo = make(map[string]*memFileInfo)
-var lock sync.RWMutex
-
-// OpenMemZipFile returns MemFile object with a compressed static file.
-// it's used for serve static file if gzip enable.
-func openMemZipFile(path string, zip string) (*memFile, error) {
- osfile, e := os.Open(path)
- if e != nil {
- return nil, e
- }
- defer osfile.Close()
-
- osfileinfo, e := osfile.Stat()
- if e != nil {
- return nil, e
- }
-
- modtime := osfileinfo.ModTime()
- fileSize := osfileinfo.Size()
- lock.RLock()
- cfi, ok := gmfim[zip+":"+path]
- lock.RUnlock()
- if !(ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize) {
- var content []byte
- if zip == "gzip" {
- var zipbuf bytes.Buffer
- gzipwriter, e := gzip.NewWriterLevel(&zipbuf, gzip.BestCompression)
- if e != nil {
- return nil, e
- }
- _, e = io.Copy(gzipwriter, osfile)
- gzipwriter.Close()
- if e != nil {
- return nil, e
- }
- content, e = ioutil.ReadAll(&zipbuf)
- if e != nil {
- return nil, e
- }
- } else if zip == "deflate" {
- var zipbuf bytes.Buffer
- deflatewriter, e := flate.NewWriter(&zipbuf, flate.BestCompression)
- if e != nil {
- return nil, e
- }
- _, e = io.Copy(deflatewriter, osfile)
- deflatewriter.Close()
- if e != nil {
- return nil, e
- }
- content, e = ioutil.ReadAll(&zipbuf)
- if e != nil {
- return nil, e
- }
- } else {
- content, e = ioutil.ReadAll(osfile)
- if e != nil {
- return nil, e
- }
- }
-
- cfi = &memFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
- lock.Lock()
- defer lock.Unlock()
- gmfim[zip+":"+path] = cfi
- }
- return &memFile{fi: cfi, offset: 0}, nil
-}
-
-// MemFileInfo contains a compressed file bytes and file information.
-// it implements os.FileInfo interface.
-type memFileInfo struct {
- os.FileInfo
- modTime time.Time
- content []byte
- contentSize int64
- fileSize int64
-}
-
-// Name returns the compressed filename.
-func (fi *memFileInfo) Name() string {
- return fi.Name()
-}
-
-// Size returns the raw file content size, not compressed size.
-func (fi *memFileInfo) Size() int64 {
- return fi.contentSize
-}
-
-// Mode returns file mode.
-func (fi *memFileInfo) Mode() os.FileMode {
- return fi.Mode()
-}
-
-// ModTime returns the last modified time of raw file.
-func (fi *memFileInfo) ModTime() time.Time {
- return fi.modTime
-}
-
-// IsDir returns the compressing file is a directory or not.
-func (fi *memFileInfo) IsDir() bool {
- return fi.IsDir()
-}
-
-// return nil. implement the os.FileInfo interface method.
-func (fi *memFileInfo) Sys() interface{} {
- return nil
-}
-
-// MemFile contains MemFileInfo and bytes offset when reading.
-// it implements io.Reader,io.ReadCloser and io.Seeker.
-type memFile struct {
- fi *memFileInfo
- offset int64
-}
-
-// Close memfile.
-func (f *memFile) Close() error {
- return nil
-}
-
-// Get os.FileInfo of memfile.
-func (f *memFile) Stat() (os.FileInfo, error) {
- return f.fi, nil
-}
-
-// read os.FileInfo of files in directory of memfile.
-// it returns empty slice.
-func (f *memFile) Readdir(count int) ([]os.FileInfo, error) {
- infos := []os.FileInfo{}
-
- return infos, nil
-}
-
-// Read bytes from the compressed file bytes.
-func (f *memFile) Read(p []byte) (n int, err error) {
- if len(f.fi.content)-int(f.offset) >= len(p) {
- n = len(p)
- } else {
- n = len(f.fi.content) - int(f.offset)
- err = io.EOF
- }
- copy(p, f.fi.content[f.offset:f.offset+int64(n)])
- f.offset += int64(n)
- return
-}
-
-var errWhence = errors.New("Seek: invalid whence")
-var errOffset = errors.New("Seek: invalid offset")
-
-// Read bytes from the compressed file bytes by seeker.
-func (f *memFile) Seek(offset int64, whence int) (ret int64, err error) {
- switch whence {
- default:
- return 0, errWhence
- case os.SEEK_SET:
- case os.SEEK_CUR:
- offset += f.offset
- case os.SEEK_END:
- offset += int64(len(f.fi.content))
- }
- if offset < 0 || int(offset) > len(f.fi.content) {
- return 0, errOffset
- }
- f.offset = offset
- return f.offset, nil
-}
-
-// GetAcceptEncodingZip returns accept encoding format in http header.
-// zip is first, then deflate if both accepted.
-// If no accepted, return empty string.
-func getAcceptEncodingZip(r *http.Request) string {
- ss := r.Header.Get("Accept-Encoding")
- ss = strings.ToLower(ss)
- if strings.Contains(ss, "gzip") {
- return "gzip"
- } else if strings.Contains(ss, "deflate") {
- return "deflate"
- } else {
- return ""
- }
-}
diff --git a/middleware/i18n.go b/middleware/i18n.go
deleted file mode 100644
index f54b4bb5..00000000
--- a/middleware/i18n.go
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2014 beego Author. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Usage:
-//
-// import "github.com/astaxie/beego/middleware"
-//
-// I18N = middleware.NewLocale("conf/i18n.conf", beego.AppConfig.String("language"))
-//
-// more docs: http://beego.me/docs/module/i18n.md
-package middleware
-
-import (
- "encoding/json"
- "io/ioutil"
- "os"
-)
-
-type Translation struct {
- filepath string
- CurrentLocal string
- Locales map[string]map[string]string
-}
-
-func NewLocale(filepath string, defaultlocal string) *Translation {
- file, err := os.Open(filepath)
- if err != nil {
- panic("open " + filepath + " err :" + err.Error())
- }
- data, err := ioutil.ReadAll(file)
- if err != nil {
- panic("read " + filepath + " err :" + err.Error())
- }
-
- i18n := make(map[string]map[string]string)
- if err = json.Unmarshal(data, &i18n); err != nil {
- panic("json.Unmarshal " + filepath + " err :" + err.Error())
- }
- return &Translation{
- filepath: filepath,
- CurrentLocal: defaultlocal,
- Locales: i18n,
- }
-}
-
-func (t *Translation) SetLocale(local string) {
- t.CurrentLocal = local
-}
-
-func (t *Translation) Translate(key string, local string) string {
- if local == "" {
- local = t.CurrentLocal
- }
- if ct, ok := t.Locales[key]; ok {
- if v, o := ct[local]; o {
- return v
- }
- }
- return key
-}
diff --git a/migration/ddl.go b/migration/ddl.go
index f9b60117..51243337 100644
--- a/migration/ddl.go
+++ b/migration/ddl.go
@@ -14,33 +14,40 @@
package migration
+// Table store the tablename and Column
type Table struct {
TableName string
Columns []*Column
}
+// Create return the create sql
func (t *Table) Create() string {
return ""
}
+// Drop return the drop sql
func (t *Table) Drop() string {
return ""
}
+// Column define the columns name type and Default
type Column struct {
Name string
Type string
Default interface{}
}
+// Create return create sql with the provided tbname and columns
func Create(tbname string, columns ...Column) string {
return ""
}
+// Drop return the drop sql with the provided tbname and columns
func Drop(tbname string, columns ...Column) string {
return ""
}
+// TableDDL is still in think
func TableDDL(tbname string, columns ...Column) string {
return ""
}
diff --git a/migration/migration.go b/migration/migration.go
index d64d60d3..1591bc50 100644
--- a/migration/migration.go
+++ b/migration/migration.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// migration package for migration
+// Package migration is used for migration
//
// The table structure is as follow:
//
@@ -39,8 +39,8 @@ import (
// const the data format for the bee generate migration datatype
const (
- M_DATE_FORMAT = "20060102_150405"
- M_DB_DATE_FORMAT = "2006-01-02 15:04:05"
+ DateFormat = "20060102_150405"
+ DBDateFormat = "2006-01-02 15:04:05"
)
// Migrationer is an interface for all Migration struct
@@ -60,24 +60,24 @@ func init() {
migrationMap = make(map[string]Migrationer)
}
-// the basic type which will implement the basic type
+// Migration the basic type which will implement the basic type
type Migration struct {
sqls []string
Created string
}
-// implement in the Inheritance struct for upgrade
+// Up implement in the Inheritance struct for upgrade
func (m *Migration) Up() {
}
-// implement in the Inheritance struct for down
+// Down implement in the Inheritance struct for down
func (m *Migration) Down() {
}
-// add sql want to execute
-func (m *Migration) Sql(sql string) {
+// SQL add sql want to execute
+func (m *Migration) SQL(sql string) {
m.sqls = append(m.sqls, sql)
}
@@ -86,7 +86,7 @@ func (m *Migration) Reset() {
m.sqls = make([]string, 0)
}
-// execute the sql already add in the sql
+// Exec execute the sql already add in the sql
func (m *Migration) Exec(name, status string) error {
o := orm.NewOrm()
for _, s := range m.sqls {
@@ -104,33 +104,32 @@ func (m *Migration) addOrUpdateRecord(name, status string) error {
o := orm.NewOrm()
if status == "down" {
status = "rollback"
- p, err := o.Raw("update migrations set `status` = ?, `rollback_statements` = ?, `created_at` = ? where name = ?").Prepare()
+ p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare()
if err != nil {
return nil
}
- _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(M_DB_DATE_FORMAT), name)
- return err
- } else {
- status = "update"
- p, err := o.Raw("insert into migrations(`name`, `created_at`, `statements`, `status`) values(?,?,?,?)").Prepare()
- if err != nil {
- return err
- }
- _, err = p.Exec(name, time.Now().Format(M_DB_DATE_FORMAT), strings.Join(m.sqls, "; "), status)
+ _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name)
return err
}
+ status = "update"
+ p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare()
+ if err != nil {
+ return err
+ }
+ _, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status)
+ return err
}
-// get the unixtime from the Created
+// GetCreated get the unixtime from the Created
func (m *Migration) GetCreated() int64 {
- t, err := time.Parse(M_DATE_FORMAT, m.Created)
+ t, err := time.Parse(DateFormat, m.Created)
if err != nil {
return 0
}
return t.Unix()
}
-// register the Migration in the map
+// Register register the Migration in the map
func Register(name string, m Migrationer) error {
if _, ok := migrationMap[name]; ok {
return errors.New("already exist name:" + name)
@@ -139,7 +138,7 @@ func Register(name string, m Migrationer) error {
return nil
}
-// upgrate the migration from lasttime
+// Upgrade upgrate the migration from lasttime
func Upgrade(lasttime int64) error {
sm := sortMap(migrationMap)
i := 0
@@ -163,7 +162,7 @@ func Upgrade(lasttime int64) error {
return nil
}
-//rollback the migration by the name
+// Rollback rollback the migration by the name
func Rollback(name string) error {
if v, ok := migrationMap[name]; ok {
beego.Info("start rollback")
@@ -178,14 +177,13 @@ func Rollback(name string) error {
beego.Info("end rollback")
time.Sleep(2 * time.Second)
return nil
- } else {
- beego.Error("not exist the migrationMap name:" + name)
- time.Sleep(2 * time.Second)
- return errors.New("not exist the migrationMap name:" + name)
}
+ beego.Error("not exist the migrationMap name:" + name)
+ time.Sleep(2 * time.Second)
+ return errors.New("not exist the migrationMap name:" + name)
}
-// reset all migration
+// Reset reset all migration
// run all migration's down function
func Reset() error {
sm := sortMap(migrationMap)
@@ -214,7 +212,7 @@ func Reset() error {
return nil
}
-// first Reset, then Upgrade
+// Refresh first Reset, then Upgrade
func Refresh() error {
err := Reset()
if err != nil {
diff --git a/mime.go b/mime.go
index 155e5e12..e85fcb2a 100644
--- a/mime.go
+++ b/mime.go
@@ -14,11 +14,7 @@
package beego
-import (
- "mime"
-)
-
-var mimemaps map[string]string = map[string]string{
+var mimemaps = map[string]string{
".3dm": "x-world/x-3dmf",
".3dmf": "x-world/x-3dmf",
".7z": "application/x-7z-compressed",
@@ -40,6 +36,7 @@ var mimemaps map[string]string = map[string]string{
".ani": "application/x-navi-animation",
".aos": "application/x-nokia-9000-communicator-add-on-software",
".aps": "application/mime",
+ ".apk": "application/vnd.android.package-archive",
".arc": "application/x-arc-compressed",
".arj": "application/arj",
".art": "image/x-jg",
@@ -557,10 +554,3 @@ var mimemaps map[string]string = map[string]string{
".oex": "application/x-opera-extension",
".mustache": "text/html",
}
-
-func initMime() error {
- for k, v := range mimemaps {
- mime.AddExtensionType(k, v)
- }
- return nil
-}
diff --git a/namespace.go b/namespace.go
index ebb7c14f..0dfdd7af 100644
--- a/namespace.go
+++ b/namespace.go
@@ -23,16 +23,17 @@ import (
type namespaceCond func(*beecontext.Context) bool
-type innnerNamespace func(*Namespace)
+// LinkNamespace used as link action
+type LinkNamespace func(*Namespace)
// Namespace is store all the info
type Namespace struct {
prefix string
- handlers *ControllerRegistor
+ handlers *ControllerRegister
}
-// get new Namespace
-func NewNamespace(prefix string, params ...innnerNamespace) *Namespace {
+// NewNamespace get new Namespace
+func NewNamespace(prefix string, params ...LinkNamespace) *Namespace {
ns := &Namespace{
prefix: prefix,
handlers: NewControllerRegister(),
@@ -43,7 +44,7 @@ func NewNamespace(prefix string, params ...innnerNamespace) *Namespace {
return ns
}
-// set condtion function
+// Cond set condtion function
// if cond return true can run this namespace, else can't
// usage:
// ns.Cond(func (ctx *context.Context) bool{
@@ -72,7 +73,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace {
return n
}
-// add filter in the Namespace
+// Filter add filter in the Namespace
// action has before & after
// FilterFunc
// usage:
@@ -95,98 +96,98 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
return n
}
-// same as beego.Rourer
+// Router same as beego.Rourer
// refer: https://godoc.org/github.com/astaxie/beego#Router
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
n.handlers.Add(rootpath, c, mappingMethods...)
return n
}
-// same as beego.AutoRouter
+// AutoRouter same as beego.AutoRouter
// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter
func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace {
n.handlers.AddAuto(c)
return n
}
-// same as beego.AutoPrefix
+// AutoPrefix same as beego.AutoPrefix
// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix
func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace {
n.handlers.AddAutoPrefix(prefix, c)
return n
}
-// same as beego.Get
+// Get same as beego.Get
// refer: https://godoc.org/github.com/astaxie/beego#Get
func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace {
n.handlers.Get(rootpath, f)
return n
}
-// same as beego.Post
+// Post same as beego.Post
// refer: https://godoc.org/github.com/astaxie/beego#Post
func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace {
n.handlers.Post(rootpath, f)
return n
}
-// same as beego.Delete
+// Delete same as beego.Delete
// refer: https://godoc.org/github.com/astaxie/beego#Delete
func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace {
n.handlers.Delete(rootpath, f)
return n
}
-// same as beego.Put
+// Put same as beego.Put
// refer: https://godoc.org/github.com/astaxie/beego#Put
func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace {
n.handlers.Put(rootpath, f)
return n
}
-// same as beego.Head
+// Head same as beego.Head
// refer: https://godoc.org/github.com/astaxie/beego#Head
func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace {
n.handlers.Head(rootpath, f)
return n
}
-// same as beego.Options
+// Options same as beego.Options
// refer: https://godoc.org/github.com/astaxie/beego#Options
func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace {
n.handlers.Options(rootpath, f)
return n
}
-// same as beego.Patch
+// Patch same as beego.Patch
// refer: https://godoc.org/github.com/astaxie/beego#Patch
func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace {
n.handlers.Patch(rootpath, f)
return n
}
-// same as beego.Any
+// Any same as beego.Any
// refer: https://godoc.org/github.com/astaxie/beego#Any
func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace {
n.handlers.Any(rootpath, f)
return n
}
-// same as beego.Handler
+// Handler same as beego.Handler
// refer: https://godoc.org/github.com/astaxie/beego#Handler
func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace {
n.handlers.Handler(rootpath, h)
return n
}
-// add include class
+// Include add include class
// refer: https://godoc.org/github.com/astaxie/beego#Include
func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
n.handlers.Include(cList...)
return n
}
-// nest Namespace
+// Namespace add nest Namespace
// usage:
//ns := beego.NewNamespace(“/v1”).
//Namespace(
@@ -230,7 +231,7 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
return n
}
-// register Namespace into beego.Handler
+// AddNamespace register Namespace into beego.Handler
// support multi Namespace
func AddNamespace(nl ...*Namespace) {
for _, n := range nl {
@@ -275,113 +276,113 @@ func addPrefix(t *Tree, prefix string) {
}
-// Namespace Condition
-func NSCond(cond namespaceCond) innnerNamespace {
+// NSCond is Namespace Condition
+func NSCond(cond namespaceCond) LinkNamespace {
return func(ns *Namespace) {
ns.Cond(cond)
}
}
-// Namespace BeforeRouter filter
-func NSBefore(filiterList ...FilterFunc) innnerNamespace {
+// NSBefore Namespace BeforeRouter filter
+func NSBefore(filiterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Filter("before", filiterList...)
}
}
-// Namespace FinishRouter filter
-func NSAfter(filiterList ...FilterFunc) innnerNamespace {
+// NSAfter add Namespace FinishRouter filter
+func NSAfter(filiterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Filter("after", filiterList...)
}
}
-// Namespace Include ControllerInterface
-func NSInclude(cList ...ControllerInterface) innnerNamespace {
+// NSInclude Namespace Include ControllerInterface
+func NSInclude(cList ...ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.Include(cList...)
}
}
-// Namespace Router
-func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) innnerNamespace {
+// NSRouter call Namespace Router
+func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace {
return func(ns *Namespace) {
ns.Router(rootpath, c, mappingMethods...)
}
}
-// Namespace Get
-func NSGet(rootpath string, f FilterFunc) innnerNamespace {
+// NSGet call Namespace Get
+func NSGet(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Get(rootpath, f)
}
}
-// Namespace Post
-func NSPost(rootpath string, f FilterFunc) innnerNamespace {
+// NSPost call Namespace Post
+func NSPost(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Post(rootpath, f)
}
}
-// Namespace Head
-func NSHead(rootpath string, f FilterFunc) innnerNamespace {
+// NSHead call Namespace Head
+func NSHead(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Head(rootpath, f)
}
}
-// Namespace Put
-func NSPut(rootpath string, f FilterFunc) innnerNamespace {
+// NSPut call Namespace Put
+func NSPut(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Put(rootpath, f)
}
}
-// Namespace Delete
-func NSDelete(rootpath string, f FilterFunc) innnerNamespace {
+// NSDelete call Namespace Delete
+func NSDelete(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Delete(rootpath, f)
}
}
-// Namespace Any
-func NSAny(rootpath string, f FilterFunc) innnerNamespace {
+// NSAny call Namespace Any
+func NSAny(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Any(rootpath, f)
}
}
-// Namespace Options
-func NSOptions(rootpath string, f FilterFunc) innnerNamespace {
+// NSOptions call Namespace Options
+func NSOptions(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Options(rootpath, f)
}
}
-// Namespace Patch
-func NSPatch(rootpath string, f FilterFunc) innnerNamespace {
+// NSPatch call Namespace Patch
+func NSPatch(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Patch(rootpath, f)
}
}
-//Namespace AutoRouter
-func NSAutoRouter(c ControllerInterface) innnerNamespace {
+// NSAutoRouter call Namespace AutoRouter
+func NSAutoRouter(c ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.AutoRouter(c)
}
}
-// Namespace AutoPrefix
-func NSAutoPrefix(prefix string, c ControllerInterface) innnerNamespace {
+// NSAutoPrefix call Namespace AutoPrefix
+func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.AutoPrefix(prefix, c)
}
}
-// Namespace add sub Namespace
-func NSNamespace(prefix string, params ...innnerNamespace) innnerNamespace {
+// NSNamespace add sub Namespace
+func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace {
return func(ns *Namespace) {
n := NewNamespace(prefix, params...)
ns.Namespace(n)
diff --git a/orm/cmd.go b/orm/cmd.go
index 2358ef3c..3638a75c 100644
--- a/orm/cmd.go
+++ b/orm/cmd.go
@@ -46,7 +46,7 @@ func printHelp(errs ...string) {
os.Exit(2)
}
-// listen for orm command and then run it if command arguments passed.
+// RunCommand listen for orm command and then run it if command arguments passed.
func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
@@ -100,7 +100,7 @@ func (d *commandSyncDb) Parse(args []string) {
func (d *commandSyncDb) Run() error {
var drops []string
if d.force {
- drops = getDbDropSql(d.al)
+ drops = getDbDropSQL(d.al)
}
db := d.al.DB
@@ -124,7 +124,7 @@ func (d *commandSyncDb) Run() error {
}
}
- sqls, indexes := getDbCreateSql(d.al)
+ sqls, indexes := getDbCreateSQL(d.al)
tables, err := d.al.DbBaser.GetTables(db)
if err != nil {
@@ -180,7 +180,7 @@ func (d *commandSyncDb) Run() error {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}
- query := idx.Sql
+ query := idx.SQL
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
@@ -203,7 +203,7 @@ func (d *commandSyncDb) Run() error {
queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] {
- queries = append(queries, idx.Sql)
+ queries = append(queries, idx.SQL)
}
for _, query := range queries {
@@ -228,12 +228,12 @@ func (d *commandSyncDb) Run() error {
}
// database creation commander interface implement.
-type commandSqlAll struct {
+type commandSQLAll struct {
al *alias
}
// parse orm command line arguments.
-func (d *commandSqlAll) Parse(args []string) {
+func (d *commandSQLAll) Parse(args []string) {
var name string
flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError)
@@ -244,13 +244,13 @@ func (d *commandSqlAll) Parse(args []string) {
}
// run orm line command.
-func (d *commandSqlAll) Run() error {
- sqls, indexes := getDbCreateSql(d.al)
+func (d *commandSQLAll) Run() error {
+ sqls, indexes := getDbCreateSQL(d.al)
var all []string
for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] {
- queries = append(queries, idx.Sql)
+ queries = append(queries, idx.SQL)
}
sql := strings.Join(queries, "\n")
all = append(all, sql)
@@ -262,10 +262,10 @@ func (d *commandSqlAll) Run() error {
func init() {
commands["syncdb"] = new(commandSyncDb)
- commands["sqlall"] = new(commandSqlAll)
+ commands["sqlall"] = new(commandSQLAll)
}
-// run syncdb command line.
+// RunSyncdb run syncdb command line.
// name means table's alias name. default is "default".
// force means run next sql if the current is error.
// verbose means show all info when running command or not.
diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go
index ea105624..da0ee8ab 100644
--- a/orm/cmd_utils.go
+++ b/orm/cmd_utils.go
@@ -23,11 +23,11 @@ import (
type dbIndex struct {
Table string
Name string
- Sql string
+ SQL string
}
// create database drop sql.
-func getDbDropSql(al *alias) (sqls []string) {
+func getDbDropSQL(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
@@ -45,13 +45,14 @@ func getDbDropSql(al *alias) (sqls []string) {
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes()
fieldType := fi.fieldType
+ fieldSize := fi.size
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeCharField:
- col = fmt.Sprintf(T["string"], fi.size)
+ col = fmt.Sprintf(T["string"], fieldSize)
case TypeTextField:
col = T["string-text"]
case TypeDateField:
@@ -65,7 +66,7 @@ checkColumn:
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
- if al.Driver == DR_Sqlite {
+ if al.Driver == DRSqlite {
fieldType = TypeIntegerField
goto checkColumn
}
@@ -89,6 +90,7 @@ checkColumn:
}
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
+ fieldSize = fi.relModelInfo.fields.pk.size
goto checkColumn
}
@@ -104,15 +106,15 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
typ += " " + "NOT NULL"
}
- return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
- Q, fi.mi.table, Q,
- Q, fi.column, Q,
+ return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
+ Q, fi.mi.table, Q,
+ Q, fi.column, Q,
typ, getColumnDefault(fi),
)
}
// create database creation string.
-func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
+func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
@@ -142,7 +144,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
if fi.auto {
switch al.Driver {
- case DR_Sqlite, DR_Postgres:
+ case DRSqlite, DRPostgres:
column += T["auto"]
default:
column += col + " " + T["auto"]
@@ -159,7 +161,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
//if fi.initial.String() != "" {
// column += " DEFAULT " + fi.initial.String()
//}
-
+
// Append attribute DEFAULT
column += getColumnDefault(fi)
@@ -201,7 +203,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
sql += strings.Join(columns, ",\n")
sql += "\n)"
- if al.Driver == DR_MySQL {
+ if al.Driver == DRMySQL {
var engine string
if mi.model != nil {
engine = getTableEngine(mi.addrField)
@@ -237,7 +239,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
index := dbIndex{}
index.Table = mi.table
index.Name = name
- index.Sql = sql
+ index.SQL = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
}
@@ -247,7 +249,6 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
return
}
-
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
func getColumnDefault(fi *fieldInfo) string {
var (
@@ -263,16 +264,20 @@ func getColumnDefault(fi *fieldInfo) string {
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType {
- case TypeDateField, TypeDateTimeField:
- return v;
-
- case TypeBooleanField, TypeBitField, TypeSmallIntegerField, TypeIntegerField,
- TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
+ case TypeDateField, TypeDateTimeField, TypeTextField:
+ return v
+
+ case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
+ TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField,
TypeDecimalField:
- d = "0"
+ t = " DEFAULT %s "
+ d = "0"
+ case TypeBooleanField:
+ t = " DEFAULT %s "
+ d = "FALSE"
}
-
+
if fi.colDefault {
if !fi.initial.Exist() {
v = fmt.Sprintf(t, "")
diff --git a/orm/db.go b/orm/db.go
index 10f65fee..b62c165b 100644
--- a/orm/db.go
+++ b/orm/db.go
@@ -24,12 +24,13 @@ import (
)
const (
- format_Date = "2006-01-02"
- format_DateTime = "2006-01-02 15:04:05"
+ formatDate = "2006-01-02"
+ formatDateTime = "2006-01-02 15:04:05"
)
var (
- ErrMissPK = errors.New("missed pk value") // missing pk error
+ // ErrMissPK missing pk error
+ ErrMissPK = errors.New("missed pk value")
)
var (
@@ -44,6 +45,8 @@ var (
"gte": true,
"lt": true,
"lte": true,
+ "eq": true,
+ "nq": true,
"startswith": true,
"endswith": true,
"istartswith": true,
@@ -214,14 +217,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
}
}
if fi.null == false && value == nil {
- return nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
+ return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
}
}
}
}
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
- if fi.auto_now || fi.auto_now_add && insert {
+ if fi.autoNow || fi.autoNowAdd && insert {
if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() {
break
@@ -280,13 +283,12 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
var id int64
err := row.Scan(&id)
return id, err
- } else {
- if res, err := stmt.Exec(values...); err == nil {
- return res.LastInsertId()
- } else {
- return 0, err
- }
}
+ res, err := stmt.Exec(values...)
+ if err == nil {
+ return res.LastInsertId()
+ }
+ return 0, err
}
// query sql ,read records and persist in dbBaser.
@@ -324,7 +326,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q)
refs := make([]interface{}, colsNum)
- for i, _ := range refs {
+ for i := range refs {
var ref interface{}
refs[i] = &ref
}
@@ -337,15 +339,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
return ErrNoRows
}
return err
- } else {
- elm := reflect.New(mi.addrField.Elem().Type())
- mind := reflect.Indirect(elm)
-
- d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
-
- ind.Set(mind)
}
-
+ elm := reflect.New(mi.addrField.Elem().Type())
+ mind := reflect.Indirect(elm)
+ d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
+ ind.Set(mind)
return nil
}
@@ -423,7 +421,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
Q := d.ins.TableQuote()
marks := make([]string, len(names))
- for i, _ := range marks {
+ for i := range marks {
marks[i] = "?"
}
@@ -442,20 +440,19 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
- if res, err := q.Exec(query, values...); err == nil {
+ res, err := q.Exec(query, values...)
+ if err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId()
- } else {
- return 0, err
}
- } else {
- row := q.QueryRow(query, values...)
- var id int64
- err := row.Scan(&id)
- return id, err
+ return 0, err
}
+ row := q.QueryRow(query, values...)
+ var id int64
+ err := row.Scan(&id)
+ return id, err
}
// execute update sql dbQuerier with given struct reflect.Value.
@@ -491,11 +488,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
d.ins.ReplaceMarks(&query)
- if res, err := q.Exec(query, setValues...); err == nil {
+ res, err := q.Exec(query, setValues...)
+ if err == nil {
return res.RowsAffected()
- } else {
- return 0, err
}
+ return 0, err
}
// execute delete sql dbQuerier with given struct reflect.Value.
@@ -511,14 +508,12 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q)
d.ins.ReplaceMarks(&query)
-
- if res, err := q.Exec(query, pkValue); err == nil {
-
+ res, err := q.Exec(query, pkValue)
+ if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
-
if num > 0 {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
@@ -527,17 +522,14 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
-
err := d.deleteRels(q, mi, []interface{}{pkValue}, tz)
if err != nil {
return num, err
}
}
-
return num, err
- } else {
- return 0, err
}
+ return 0, err
}
// update table-related record by querySet.
@@ -563,11 +555,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
tables.parseRelated(qs.related, qs.relDepth)
}
- where, args := tables.getCondSql(cond, false, tz)
+ where, args := tables.getCondSQL(cond, false, tz)
values = append(values, args...)
- join := tables.getJoinSql()
+ join := tables.getJoinSQL()
var query, T string
@@ -583,13 +575,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q)
if c, ok := values[i].(colValue); ok {
switch c.opt {
- case Col_Add:
+ case ColAdd:
cols = append(cols, col+" = "+col+" + ?")
- case Col_Minus:
+ case ColMinus:
cols = append(cols, col+" = "+col+" - ?")
- case Col_Multiply:
+ case ColMultiply:
cols = append(cols, col+" = "+col+" * ?")
- case Col_Except:
+ case ColExcept:
cols = append(cols, col+" = "+col+" / ?")
}
values[i] = c.value
@@ -608,12 +600,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
d.ins.ReplaceMarks(&query)
-
- if res, err := q.Exec(query, values...); err == nil {
+ res, err := q.Exec(query, values...)
+ if err == nil {
return res.RowsAffected()
- } else {
- return 0, err
}
+ return 0, err
}
// delete related records.
@@ -622,23 +613,23 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
switch fi.onDelete {
- case od_CASCADE:
+ case odCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
if err != nil {
return err
}
- case od_SET_DEFAULT, od_SET_NULL:
+ case odSetDefault, odSetNULL:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
params := Params{fi.column: nil}
- if fi.onDelete == od_SET_DEFAULT {
+ if fi.onDelete == odSetDefault {
params[fi.column] = fi.initial.String()
}
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
if err != nil {
return err
}
- case od_DO_NOTHING:
+ case odDoNothing:
}
}
return nil
@@ -659,8 +650,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
Q := d.ins.TableQuote()
- where, args := tables.getCondSql(cond, false, tz)
- join := tables.getJoinSql()
+ where, args := tables.getCondSQL(cond, false, tz)
+ join := tables.getJoinSQL()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where)
@@ -668,16 +659,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
- if r, err := q.Query(query, args...); err != nil {
+ r, err := q.Query(query, args...)
+ if err != nil {
return 0, err
- } else {
- rs = r
}
-
+ rs = r
defer rs.Close()
var ref interface{}
-
args = make([]interface{}, 0)
cnt := 0
for rs.Next() {
@@ -693,31 +682,28 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
marks := make([]string, len(args))
- for i, _ := range marks {
+ for i := range marks {
marks[i] = "?"
}
sql := fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
d.ins.ReplaceMarks(&query)
-
- if res, err := q.Exec(query, args...); err == nil {
+ res, err := q.Exec(query, args...)
+ if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
-
if num > 0 {
err := d.deleteRels(q, mi, args, tz)
if err != nil {
return num, err
}
}
-
return num, nil
- } else {
- return 0, err
}
+ return 0, err
}
// read related records.
@@ -799,10 +785,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
- where, args := tables.getCondSql(cond, false, tz)
- orderBy := tables.getOrderSql(qs.orders)
- limit := tables.getLimitSql(mi, offset, rlimit)
- join := tables.getJoinSql()
+ where, args := tables.getCondSQL(cond, false, tz)
+ groupBy := tables.getGroupSQL(qs.groups)
+ orderBy := tables.getOrderSQL(qs.orders)
+ limit := tables.getLimitSQL(mi, offset, rlimit)
+ join := tables.getJoinSQL()
for _, tbl := range tables.tables {
if tbl.sel {
@@ -812,19 +799,23 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
}
}
- query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
+ sqlSelect := "SELECT"
+ if qs.distinct {
+ sqlSelect += " DISTINCT"
+ }
+ query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
- if r, err := q.Query(query, args...); err != nil {
+ r, err := q.Query(query, args...)
+ if err != nil {
return 0, err
- } else {
- rs = r
}
+ rs = r
refs := make([]interface{}, colsNum)
- for i, _ := range refs {
+ for i := range refs {
var ref interface{}
refs[i] = &ref
}
@@ -935,9 +926,9 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
- where, args := tables.getCondSql(cond, false, tz)
- tables.getOrderSql(qs.orders)
- join := tables.getJoinSql()
+ where, args := tables.getCondSQL(cond, false, tz)
+ tables.getOrderSQL(qs.orders)
+ join := tables.getJoinSQL()
Q := d.ins.TableQuote()
@@ -952,7 +943,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
}
// generate sql with replacing operator string placeholders and replaced values.
-func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
+func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := ""
params := getFlatParams(fi, args, tz)
@@ -964,7 +955,7 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
switch operator {
case "in":
marks := make([]string, len(params))
- for i, _ := range marks {
+ for i := range marks {
marks[i] = "?"
}
sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
@@ -977,7 +968,7 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
if len(params) > 1 {
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
}
- sql = d.ins.OperatorSql(operator)
+ sql = d.ins.OperatorSQL(operator)
switch operator {
case "exact":
if arg == nil {
@@ -1105,12 +1096,12 @@ setValue:
)
if len(s) >= 19 {
s = s[:19]
- t, err = time.ParseInLocation(format_DateTime, s, tz)
+ t, err = time.ParseInLocation(formatDateTime, s, tz)
} else {
if len(s) > 10 {
s = s[:10]
}
- t, err = time.ParseInLocation(format_Date, s, tz)
+ t, err = time.ParseInLocation(formatDate, s, tz)
}
t = t.In(DefaultTimeLoc)
@@ -1441,26 +1432,24 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
}
}
- where, args := tables.getCondSql(cond, false, tz)
- orderBy := tables.getOrderSql(qs.orders)
- limit := tables.getLimitSql(mi, qs.offset, qs.limit)
- join := tables.getJoinSql()
+ where, args := tables.getCondSQL(cond, false, tz)
+ groupBy := tables.getGroupSQL(qs.groups)
+ orderBy := tables.getOrderSQL(qs.orders)
+ limit := tables.getLimitSQL(mi, qs.offset, qs.limit)
+ join := tables.getJoinSQL()
sels := strings.Join(cols, ", ")
- query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
+ query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query)
- var rs *sql.Rows
- if r, err := q.Query(query, args...); err != nil {
+ rs, err := q.Query(query, args...)
+ if err != nil {
return 0, err
- } else {
- rs = r
}
-
refs := make([]interface{}, len(cols))
- for i, _ := range refs {
+ for i := range refs {
var ref interface{}
refs[i] = &ref
}
@@ -1473,11 +1462,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
)
for rs.Next() {
if cnt == 0 {
- if cols, err := rs.Columns(); err != nil {
+ cols, err := rs.Columns()
+ if err != nil {
return 0, err
- } else {
- columns = cols
}
+ columns = cols
}
if err := rs.Scan(refs...); err != nil {
@@ -1641,7 +1630,7 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
}
// not implement.
-func (d *dbBase) OperatorSql(operator string) string {
+func (d *dbBase) OperatorSQL(operator string) string {
panic(ErrNotImplement)
}
diff --git a/orm/db_alias.go b/orm/db_alias.go
index 0a862241..79576b8e 100644
--- a/orm/db_alias.go
+++ b/orm/db_alias.go
@@ -22,15 +22,17 @@ import (
"time"
)
-// database driver constant int.
+// DriverType database driver constant int.
type DriverType int
+// Enum the Database driver
const (
- _ DriverType = iota // int enum type
- DR_MySQL // mysql
- DR_Sqlite // sqlite
- DR_Oracle // oracle
- DR_Postgres // pgsql
+ _ DriverType = iota // int enum type
+ DRMySQL // mysql
+ DRSqlite // sqlite
+ DROracle // oracle
+ DRPostgres // pgsql
+ DRTiDB // TiDB
)
// database driver string.
@@ -53,15 +55,17 @@ var _ Driver = new(driver)
var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]DriverType{
- "mysql": DR_MySQL,
- "postgres": DR_Postgres,
- "sqlite3": DR_Sqlite,
+ "mysql": DRMySQL,
+ "postgres": DRPostgres,
+ "sqlite3": DRSqlite,
+ "tidb": DRTiDB,
}
dbBasers = map[DriverType]dbBaser{
- DR_MySQL: newdbBaseMysql(),
- DR_Sqlite: newdbBaseSqlite(),
- DR_Oracle: newdbBaseMysql(),
- DR_Postgres: newdbBasePostgres(),
+ DRMySQL: newdbBaseMysql(),
+ DRSqlite: newdbBaseSqlite(),
+ DROracle: newdbBaseOracle(),
+ DRPostgres: newdbBasePostgres(),
+ DRTiDB: newdbBaseTidb(),
}
)
@@ -119,7 +123,7 @@ func detectTZ(al *alias) {
}
switch al.Driver {
- case DR_MySQL:
+ case DRMySQL:
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string
row.Scan(&tz)
@@ -147,10 +151,10 @@ func detectTZ(al *alias) {
al.Engine = "INNODB"
}
- case DR_Sqlite:
+ case DRSqlite:
al.TZ = time.UTC
- case DR_Postgres:
+ case DRPostgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string
row.Scan(&tz)
@@ -188,12 +192,13 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
return al, nil
}
+// AddAliasWthDB add a aliasName for the drivename
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db)
return err
}
-// Setting the database connect params. Use the database driver self dataSource args.
+// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var (
err error
@@ -236,7 +241,7 @@ end:
return err
}
-// Register a database driver use specify driver name, this can be definition the driver is which database type.
+// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ
@@ -248,7 +253,7 @@ func RegisterDriver(driverName string, typ DriverType) error {
return nil
}
-// Change the database default used timezone
+// SetDataBaseTZ Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz
@@ -258,14 +263,14 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
return nil
}
-// Change the max idle conns for *sql.DB, use specify database alias name
+// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns
al.DB.SetMaxIdleConns(maxIdleConns)
}
-// Change the max open conns for *sql.DB, use specify database alias name
+// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
func SetMaxOpenConns(aliasName string, maxOpenConns int) {
al := getDbAlias(aliasName)
al.MaxOpenConns = maxOpenConns
@@ -275,7 +280,7 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
}
}
-// Get *sql.DB from registered database by db alias name.
+// GetDB Get *sql.DB from registered database by db alias name.
// Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string
@@ -284,9 +289,9 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} else {
name = "default"
}
- if al, ok := dataBaseCache.get(name); ok {
+ al, ok := dataBaseCache.get(name)
+ if ok {
return al.DB, nil
- } else {
- return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
}
+ return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
}
diff --git a/orm/db_mysql.go b/orm/db_mysql.go
index 3c2ad3a4..10fe2657 100644
--- a/orm/db_mysql.go
+++ b/orm/db_mysql.go
@@ -30,6 +30,8 @@ var mysqlOperators = map[string]string{
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
+ "eq": "= ?",
+ "ne": "!= ?",
"startswith": "LIKE BINARY ?",
"endswith": "LIKE BINARY ?",
"istartswith": "LIKE ?",
@@ -65,7 +67,7 @@ type dbBaseMysql struct {
var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
-func (d *dbBaseMysql) OperatorSql(operator string) string {
+func (d *dbBaseMysql) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}
diff --git a/orm/db_postgres.go b/orm/db_postgres.go
index 296ee6a0..7dbef95a 100644
--- a/orm/db_postgres.go
+++ b/orm/db_postgres.go
@@ -29,6 +29,8 @@ var postgresOperators = map[string]string{
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
+ "eq": "= ?",
+ "ne": "!= ?",
"startswith": "LIKE ?",
"endswith": "LIKE ?",
"istartswith": "LIKE UPPER(?)",
@@ -64,7 +66,7 @@ type dbBasePostgres struct {
var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
-func (d *dbBasePostgres) OperatorSql(operator string) string {
+func (d *dbBasePostgres) OperatorSQL(operator string) string {
return postgresOperators[operator]
}
@@ -99,7 +101,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
num := 0
for _, c := range q {
if c == '?' {
- num += 1
+ num++
}
}
if num == 0 {
@@ -112,7 +114,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
if c == '?' {
data = append(data, '$')
data = append(data, []byte(strconv.Itoa(num))...)
- num += 1
+ num++
} else {
data = append(data, c)
}
diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go
index 0a2f32c8..a3cb69a7 100644
--- a/orm/db_sqlite.go
+++ b/orm/db_sqlite.go
@@ -29,6 +29,8 @@ var sqliteOperators = map[string]string{
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
+ "eq": "= ?",
+ "ne": "!= ?",
"startswith": "LIKE ? ESCAPE '\\'",
"endswith": "LIKE ? ESCAPE '\\'",
"istartswith": "LIKE ? ESCAPE '\\'",
@@ -64,7 +66,7 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator.
-func (d *dbBaseSqlite) OperatorSql(operator string) string {
+func (d *dbBaseSqlite) OperatorSQL(operator string) string {
return sqliteOperators[operator]
}
diff --git a/orm/db_tables.go b/orm/db_tables.go
index a9aa10ab..e4c74ace 100644
--- a/orm/db_tables.go
+++ b/orm/db_tables.go
@@ -164,7 +164,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
}
// generate join string.
-func (t *dbTables) getJoinSql() (join string) {
+func (t *dbTables) getJoinSQL() (join string) {
Q := t.base.TableQuote()
for _, jt := range t.tables {
@@ -186,7 +186,7 @@ func (t *dbTables) getJoinSql() (join string) {
table = jt.mi.table
switch {
- case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
+ case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
@@ -220,7 +220,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
)
num := len(exprs) - 1
- names := make([]string, 0)
+ var names []string
inner := true
@@ -326,7 +326,7 @@ loopFor:
}
// generate condition sql.
-func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
+func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
@@ -347,7 +347,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
where += "NOT "
}
if p.isCond {
- w, ps := t.getCondSql(p.cond, true, tz)
+ w, ps := t.getCondSQL(p.cond, true, tz)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
@@ -372,12 +372,12 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
operator = "exact"
}
- operSql, args := t.base.GenerateOperatorSql(mi, fi, operator, p.args, tz)
+ operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
- where += fmt.Sprintf("%s %s ", leftCol, operSql)
+ where += fmt.Sprintf("%s %s ", leftCol, operSQL)
params = append(params, args...)
}
@@ -390,8 +390,32 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
return
}
+// generate group sql.
+func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
+ if len(groups) == 0 {
+ return
+ }
+
+ Q := t.base.TableQuote()
+
+ groupSqls := make([]string, 0, len(groups))
+ for _, group := range groups {
+ exprs := strings.Split(group, ExprSep)
+
+ index, _, fi, suc := t.parseExprs(t.mi, exprs)
+ if suc == false {
+ panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
+ }
+
+ groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
+ }
+
+ groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
+ return
+}
+
// generate order sql.
-func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
+func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
if len(orders) == 0 {
return
}
@@ -415,12 +439,12 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
}
- orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
+ orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return
}
// generate limit sql.
-func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
+func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 {
limit = int64(DefaultRowsLimit)
}
diff --git a/orm/db_tidb.go b/orm/db_tidb.go
new file mode 100644
index 00000000..6020a488
--- /dev/null
+++ b/orm/db_tidb.go
@@ -0,0 +1,63 @@
+// Copyright 2015 TiDB Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package orm
+
+import (
+ "fmt"
+)
+
+// mysql dbBaser implementation.
+type dbBaseTidb struct {
+ dbBase
+}
+
+var _ dbBaser = new(dbBaseTidb)
+
+// get mysql operator.
+func (d *dbBaseTidb) OperatorSQL(operator string) string {
+ return mysqlOperators[operator]
+}
+
+// get mysql table field types.
+func (d *dbBaseTidb) DbTypes() map[string]string {
+ return mysqlTypes
+}
+
+// show table sql for mysql.
+func (d *dbBaseTidb) ShowTablesQuery() string {
+ return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
+}
+
+// show columns sql of table for mysql.
+func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
+ return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
+ "WHERE table_schema = DATABASE() AND table_name = '%s'", table)
+}
+
+// execute sql to check index exist.
+func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
+ row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
+ var cnt int
+ row.Scan(&cnt)
+ return cnt > 0
+}
+
+// create new mysql dbBaser.
+func newdbBaseTidb() dbBaser {
+ b := new(dbBaseTidb)
+ b.ins = b
+ return b
+}
diff --git a/orm/db_utils.go b/orm/db_utils.go
index 4a3ba464..ae9b1625 100644
--- a/orm/db_utils.go
+++ b/orm/db_utils.go
@@ -24,9 +24,8 @@ import (
func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok {
return al
- } else {
- panic(fmt.Errorf("unknown DataBase alias name %s", name))
}
+ panic(fmt.Errorf("unknown DataBase alias name %s", name))
}
// get pk column info.
@@ -80,19 +79,19 @@ outFor:
var err error
if len(v) >= 19 {
s := v[:19]
- t, err = time.ParseInLocation(format_DateTime, s, DefaultTimeLoc)
+ t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
} else {
s := v
if len(v) > 10 {
s = v[:10]
}
- t, err = time.ParseInLocation(format_Date, s, tz)
+ t, err = time.ParseInLocation(formatDate, s, tz)
}
if err == nil {
if fi.fieldType == TypeDateField {
- v = t.In(tz).Format(format_Date)
+ v = t.In(tz).Format(formatDate)
} else {
- v = t.In(tz).Format(format_DateTime)
+ v = t.In(tz).Format(formatDateTime)
}
}
}
@@ -137,9 +136,9 @@ outFor:
case reflect.Struct:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
- arg = v.In(tz).Format(format_Date)
+ arg = v.In(tz).Format(formatDate)
} else {
- arg = v.In(tz).Format(format_DateTime)
+ arg = v.In(tz).Format(formatDateTime)
}
} else {
typ := val.Type()
diff --git a/orm/models.go b/orm/models.go
index dcb32b55..faf551be 100644
--- a/orm/models.go
+++ b/orm/models.go
@@ -19,10 +19,10 @@ import (
)
const (
- od_CASCADE = "cascade"
- od_SET_NULL = "set_null"
- od_SET_DEFAULT = "set_default"
- od_DO_NOTHING = "do_nothing"
+ odCascade = "cascade"
+ odSetNULL = "set_null"
+ odSetDefault = "set_default"
+ odDoNothing = "do_nothing"
defaultStructTagName = "orm"
defaultStructTagDelim = ";"
)
@@ -113,7 +113,7 @@ func (mc *_modelCache) clean() {
mc.done = false
}
-// Clean model cache. Then you can re-RegisterModel.
+// ResetModelCache Clean model cache. Then you can re-RegisterModel.
// Common use this api for test case.
func ResetModelCache() {
modelCache.clean()
diff --git a/orm/models_boot.go b/orm/models_boot.go
index cb44bc05..3690557b 100644
--- a/orm/models_boot.go
+++ b/orm/models_boot.go
@@ -51,19 +51,16 @@ func registerModel(prefix string, model interface{}) {
}
info := newModelInfo(val)
-
if info.fields.pk == nil {
outFor:
for _, fi := range info.fields.fieldsDB {
- if fi.name == "Id" {
- if fi.sf.Tag.Get(defaultStructTagName) == "" {
- switch fi.addrValue.Elem().Kind() {
- case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
- fi.auto = true
- fi.pk = true
- info.fields.pk = fi
- break outFor
- }
+ if strings.ToLower(fi.name) == "id" {
+ switch fi.addrValue.Elem().Kind() {
+ case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
+ fi.auto = true
+ fi.pk = true
+ info.fields.pk = fi
+ break outFor
}
}
}
@@ -269,7 +266,10 @@ func bootStrap() {
if found == false {
mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
- if ffi.relModelInfo == mi {
+ conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
+ fi.relTable != "" && fi.relTable == ffi.relTable ||
+ fi.relThrough == "" && fi.relTable == ""
+ if ffi.relModelInfo == mi && conditions {
found = true
fi.reverseField = ffi.reverseFieldInfoTwo.name
@@ -298,12 +298,12 @@ end:
}
}
-// register models
+// RegisterModel register models
func RegisterModel(models ...interface{}) {
RegisterModelWithPrefix("", models...)
}
-// register models with a prefix
+// RegisterModelWithPrefix register models with a prefix
func RegisterModelWithPrefix(prefix string, models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
@@ -314,7 +314,7 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
}
}
-// bootrap models.
+// BootStrap bootrap models.
// make all model parsed and can not add more models
func BootStrap() {
if modelCache.done {
diff --git a/orm/models_fields.go b/orm/models_fields.go
index f038dd0f..a8cf8e4f 100644
--- a/orm/models_fields.go
+++ b/orm/models_fields.go
@@ -15,49 +15,28 @@
package orm
import (
- "errors"
"fmt"
"strconv"
"time"
)
+// Define the Type enum
const (
- // bool
TypeBooleanField = 1 << iota
-
- // string
TypeCharField
-
- // string
TypeTextField
-
- // time.Time
TypeDateField
- // time.Time
TypeDateTimeField
-
- // int8
TypeBitField
- // int16
TypeSmallIntegerField
- // int32
TypeIntegerField
- // int64
TypeBigIntegerField
- // uint8
TypePositiveBitField
- // uint16
TypePositiveSmallIntegerField
- // uint32
TypePositiveIntegerField
- // uint64
TypePositiveBigIntegerField
-
- // float64
TypeFloatField
- // float64
TypeDecimalField
-
RelForeignKey
RelOneToOne
RelManyToMany
@@ -65,6 +44,7 @@ const (
RelReverseMany
)
+// Define some logic enum
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9
@@ -72,25 +52,30 @@ const (
IsFieldType = ^-RelReverseMany<<1 + 1
)
-// A true/false field.
+// BooleanField A true/false field.
type BooleanField bool
+// Value return the BooleanField
func (e BooleanField) Value() bool {
return bool(e)
}
+// Set will set the BooleanField
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
+// String format the Bool to string
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
+// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
+// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
@@ -102,56 +87,65 @@ func (e *BooleanField) SetRaw(value interface{}) error {
}
return err
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return the current value
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
+// verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField)
-// A string field
+// CharField A string field
// required values tag: size
// The size is enforced at the database level and in models’s validation.
// eg: `orm:"size(120)"`
type CharField string
+// Value return the CharField's Value
func (e CharField) Value() string {
return string(e)
}
+// Set CharField value
func (e *CharField) Set(d string) {
*e = CharField(d)
}
+// String return the CharField
func (e *CharField) String() string {
return e.Value()
}
+// FieldType return the enum type
func (e *CharField) FieldType() int {
return TypeCharField
}
+// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return the CharField value
func (e *CharField) RawValue() interface{} {
return e.Value()
}
+// verify CharField implement Fielder
var _ Fielder = new(CharField)
-// A date, represented in go by a time.Time instance.
+// DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02
// Has a few extra, optional attr tag:
//
@@ -166,106 +160,125 @@ var _ Fielder = new(CharField)
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time
+// Value return the time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
+// Set set the DateField's value
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
+// String convert datatime to string
func (e *DateField) String() string {
return e.Value().String()
}
+// FieldType return enum type Date
func (e *DateField) FieldType() int {
return TypeDateField
}
+// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
- v, err := timeParse(d, format_Date)
+ v, err := timeParse(d, formatDate)
if err != nil {
e.Set(v)
}
return err
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return Date value
func (e *DateField) RawValue() interface{} {
return e.Value()
}
+// verify DateField implement fielder interface
var _ Fielder = new(DateField)
-// A date, represented in go by a time.Time instance.
+// DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
+// Value return the datatime value
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
+// Set set the time.Time to datatime
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
+// String return the time's String
func (e *DateTimeField) String() string {
return e.Value().String()
}
+// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
+// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
- v, err := timeParse(d, format_DateTime)
+ v, err := timeParse(d, formatDateTime)
if err != nil {
e.Set(v)
}
return err
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return the datatime value
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
+// verify datatime implement fielder
var _ Fielder = new(DateTimeField)
-// A floating-point number represented in go by a float32 value.
+// FloatField A floating-point number represented in go by a float32 value.
type FloatField float64
+// Value return the FloatField value
func (e FloatField) Value() float64 {
return float64(e)
}
+// Set the Float64
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
+// String return the string
func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32)
}
+// FieldType return the enum type
func (e *FloatField) FieldType() int {
return TypeFloatField
}
+// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
@@ -278,36 +291,43 @@ func (e *FloatField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
+// verify FloatField implement Fielder
var _ Fielder = new(FloatField)
-// -32768 to 32767
+// SmallIntegerField -32768 to 32767
type SmallIntegerField int16
+// Value return int16 value
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
+// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
+// String convert smallint to string
func (e *SmallIntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
+// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
@@ -318,36 +338,43 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
+// verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField)
-// -2147483648 to 2147483647
+// IntegerField -2147483648 to 2147483647
type IntegerField int32
+// Value return the int32
func (e IntegerField) Value() int32 {
return int32(e)
}
+// Set IntegerField value
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
+// String convert Int32 to string
func (e *IntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return the enum type
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
+// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
@@ -358,36 +385,43 @@ func (e *IntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
+// verify IntegerField implement Fielder
var _ Fielder = new(IntegerField)
-// -9223372036854775808 to 9223372036854775807.
+// BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64
+// Value return int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
+// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
+// String convert BigIntegerField to string
func (e *BigIntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return enum type
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
+// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
@@ -398,36 +432,43 @@ func (e *BigIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
+// verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField)
-// 0 to 65535
+// PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16
+// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
+// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
+// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
+// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
@@ -438,36 +479,43 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
+// verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField)
-// 0 to 4294967295
+// PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32
+// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
+// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
+// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
+// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
@@ -478,36 +526,43 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
+// verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField)
-// 0 to 18446744073709551615
+// PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64
+// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
+// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
+// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value())
}
+// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
+// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
@@ -518,48 +573,57 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
+// verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField)
-// A large text field.
+// TextField A large text field.
type TextField string
+// Value return TextField value
func (e TextField) Value() string {
return string(e)
}
+// Set the TextField value
func (e *TextField) Set(d string) {
*e = TextField(d)
}
+// String convert TextField to string
func (e *TextField) String() string {
return e.Value()
}
+// FieldType return enum type
func (e *TextField) FieldType() int {
return TypeTextField
}
+// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
- return errors.New(fmt.Sprintf(" unknown value `%s`", value))
+ return fmt.Errorf(" unknown value `%s`", value)
}
return nil
}
+// RawValue return TextField value
func (e *TextField) RawValue() interface{} {
return e.Value()
}
+// verify TextField implement Fielder
var _ Fielder = new(TextField)
diff --git a/orm/models_info_f.go b/orm/models_info_f.go
index 84a0c024..14e1f2c6 100644
--- a/orm/models_info_f.go
+++ b/orm/models_info_f.go
@@ -119,8 +119,8 @@ type fieldInfo struct {
colDefault bool
initial StrTo
size int
- auto_now bool
- auto_now_add bool
+ autoNow bool
+ autoNowAdd bool
rel bool
reverse bool
reverseField string
@@ -223,6 +223,11 @@ checkType:
break checkType
case "many":
fieldType = RelReverseMany
+ if tv := tags["rel_table"]; tv != "" {
+ fi.relTable = tv
+ } else if tv := tags["rel_through"]; tv != "" {
+ fi.relThrough = tv
+ }
break checkType
default:
err = fmt.Errorf("error")
@@ -309,20 +314,20 @@ checkType:
if fi.rel && fi.dbcol {
switch onDelete {
- case od_CASCADE, od_DO_NOTHING:
- case od_SET_DEFAULT:
+ case odCascade, odDoNothing:
+ case odSetDefault:
if initial.Exist() == false {
err = errors.New("on_delete: set_default need set field a default value")
goto end
}
- case od_SET_NULL:
+ case odSetNULL:
if fi.null == false {
err = errors.New("on_delete: set_null need set field null")
goto end
}
default:
if onDelete == "" {
- onDelete = od_CASCADE
+ onDelete = odCascade
} else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end
@@ -350,9 +355,9 @@ checkType:
fi.unique = false
case TypeDateField, TypeDateTimeField:
if attrs["auto_now"] {
- fi.auto_now = true
+ fi.autoNow = true
} else if attrs["auto_now_add"] {
- fi.auto_now_add = true
+ fi.autoNowAdd = true
}
case TypeFloatField:
case TypeDecimalField:
diff --git a/orm/models_info_m.go b/orm/models_info_m.go
index 3600ee7c..2654cdb5 100644
--- a/orm/models_info_m.go
+++ b/orm/models_info_m.go
@@ -15,7 +15,6 @@
package orm
import (
- "errors"
"fmt"
"os"
"reflect"
@@ -72,13 +71,13 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
added := info.fields.Add(fi)
if added == false {
- err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
+ err = fmt.Errorf("duplicate column name: %s", fi.column)
break
}
if fi.pk {
if info.fields.pk != nil {
- err = errors.New(fmt.Sprintf("one model must have one pk field only"))
+ err = fmt.Errorf("one model must have one pk field only")
break
} else {
info.fields.pk = fi
diff --git a/orm/models_test.go b/orm/models_test.go
index 1a92ef5d..ee56e8e8 100644
--- a/orm/models_test.go
+++ b/orm/models_test.go
@@ -25,6 +25,9 @@ import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
+
+ // As tidb can't use go get, so disable the tidb testing now
+ // _ "github.com/pingcap/tidb"
)
// A slice string field.
@@ -76,21 +79,21 @@ func (e *SliceStringField) RawValue() interface{} {
var _ Fielder = new(SliceStringField)
// A json field.
-type JsonField struct {
+type JSONField struct {
Name string
Data string
}
-func (e *JsonField) String() string {
+func (e *JSONField) String() string {
data, _ := json.Marshal(e)
return string(data)
}
-func (e *JsonField) FieldType() int {
+func (e *JSONField) FieldType() int {
return TypeTextField
}
-func (e *JsonField) SetRaw(value interface{}) error {
+func (e *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
return json.Unmarshal([]byte(d), e)
@@ -99,14 +102,14 @@ func (e *JsonField) SetRaw(value interface{}) error {
}
}
-func (e *JsonField) RawValue() interface{} {
+func (e *JSONField) RawValue() interface{} {
return e.String()
}
-var _ Fielder = new(JsonField)
+var _ Fielder = new(JSONField)
type Data struct {
- Id int
+ ID int `orm:"column(id)"`
Boolean bool
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
@@ -130,7 +133,7 @@ type Data struct {
}
type DataNull struct {
- Id int
+ ID int `orm:"column(id)"`
Boolean bool `orm:"null"`
Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"`
@@ -193,7 +196,7 @@ type Float32 float64
type Float64 float64
type DataCustom struct {
- Id int
+ ID int `orm:"column(id)"`
Boolean Boolean
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
@@ -216,40 +219,40 @@ type DataCustom struct {
// only for mysql
type UserBig struct {
- Id uint64
+ ID uint64 `orm:"column(id)"`
Name string
}
type User struct {
- Id int
- UserName string `orm:"size(30);unique"`
- Email string `orm:"size(100)"`
- Password string `orm:"size(100)"`
- Status int16 `orm:"column(Status)"`
- IsStaff bool
- IsActive bool `orm:"default(1)"`
- Created time.Time `orm:"auto_now_add;type(date)"`
- Updated time.Time `orm:"auto_now"`
- Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
- Posts []*Post `orm:"reverse(many)" json:"-"`
- ShouldSkip string `orm:"-"`
- Nums int
- Langs SliceStringField `orm:"size(100)"`
- Extra JsonField `orm:"type(text)"`
- unexport bool `orm:"-"`
- unexport_ bool
+ ID int `orm:"column(id)"`
+ UserName string `orm:"size(30);unique"`
+ Email string `orm:"size(100)"`
+ Password string `orm:"size(100)"`
+ Status int16 `orm:"column(Status)"`
+ IsStaff bool
+ IsActive bool `orm:"default(true)"`
+ Created time.Time `orm:"auto_now_add;type(date)"`
+ Updated time.Time `orm:"auto_now"`
+ Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
+ Posts []*Post `orm:"reverse(many)" json:"-"`
+ ShouldSkip string `orm:"-"`
+ Nums int
+ Langs SliceStringField `orm:"size(100)"`
+ Extra JSONField `orm:"type(text)"`
+ unexport bool `orm:"-"`
+ unexportBool bool
}
func (u *User) TableIndex() [][]string {
return [][]string{
- []string{"Id", "UserName"},
- []string{"Id", "Created"},
+ {"Id", "UserName"},
+ {"Id", "Created"},
}
}
func (u *User) TableUnique() [][]string {
return [][]string{
- []string{"UserName", "Email"},
+ {"UserName", "Email"},
}
}
@@ -259,7 +262,7 @@ func NewUser() *User {
}
type Profile struct {
- Id int
+ ID int `orm:"column(id)"`
Age int16
Money float64
User *User `orm:"reverse(one)" json:"-"`
@@ -276,7 +279,7 @@ func NewProfile() *Profile {
}
type Post struct {
- Id int
+ ID int `orm:"column(id)"`
User *User `orm:"rel(fk)"`
Title string `orm:"size(60)"`
Content string `orm:"type(text)"`
@@ -287,7 +290,7 @@ type Post struct {
func (u *Post) TableIndex() [][]string {
return [][]string{
- []string{"Id", "Created"},
+ {"Id", "Created"},
}
}
@@ -297,7 +300,7 @@ func NewPost() *Post {
}
type Tag struct {
- Id int
+ ID int `orm:"column(id)"`
Name string `orm:"size(30)"`
BestPost *Post `orm:"rel(one);null"`
Posts []*Post `orm:"reverse(many)" json:"-"`
@@ -309,7 +312,7 @@ func NewTag() *Tag {
}
type PostTags struct {
- Id int
+ ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk)"`
Tag *Tag `orm:"rel(fk)"`
}
@@ -319,7 +322,7 @@ func (m *PostTags) TableName() string {
}
type Comment struct {
- Id int
+ ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk);column(post)"`
Content string `orm:"type(text)"`
Parent *Comment `orm:"null;rel(fk)"`
@@ -331,6 +334,24 @@ func NewComment() *Comment {
return obj
}
+type Group struct {
+ ID int `orm:"column(gid);size(32)"`
+ Name string
+ Permissions []*Permission `orm:"reverse(many)" json:"-"`
+}
+
+type Permission struct {
+ ID int `orm:"column(id)"`
+ Name string
+ Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"`
+}
+
+type GroupPermissions struct {
+ ID int `orm:"column(id)"`
+ Group *Group `orm:"rel(fk)"`
+ Permission *Permission `orm:"rel(fk)"`
+}
+
var DBARGS = struct {
Driver string
Source string
@@ -345,6 +366,7 @@ var (
IsMysql = DBARGS.Driver == "mysql"
IsSqlite = DBARGS.Driver == "sqlite3"
IsPostgres = DBARGS.Driver == "postgres"
+ IsTidb = DBARGS.Driver == "tidb"
)
var (
@@ -364,6 +386,7 @@ Default DB Drivers.
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq
+tidb: https://github.com/pingcap/tidb
usage:
@@ -371,6 +394,7 @@ go get -u github.com/astaxie/beego/orm
go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq
+go get -u github.com/pingcap/tidb
#### MySQL
mysql -u root -e 'create database orm_test;'
@@ -390,6 +414,12 @@ psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/astaxie/beego/orm
+
+#### TiDB
+export ORM_DRIVER=tidb
+export ORM_SOURCE='memory://test/test'
+go test -v github.com/astaxie/beego/orm
+
`)
os.Exit(2)
}
@@ -397,7 +427,7 @@ go test -v github.com/astaxie/beego/orm
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
alias := getDbAlias("default")
- if alias.Driver == DR_MySQL {
+ if alias.Driver == DRMySQL {
alias.Engine = "INNODB"
}
diff --git a/orm/orm.go b/orm/orm.go
index f881433b..d00d6d03 100644
--- a/orm/orm.go
+++ b/orm/orm.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage
//
// package main
@@ -59,12 +60,13 @@ import (
"time"
)
+// DebugQueries define the debug
const (
- Debug_Queries = iota
+ DebugQueries = iota
)
+// Define common vars
var (
- // DebugLevel = Debug_Queries
Debug = false
DebugLog = NewLog(os.Stderr)
DefaultRowsLimit = 1000
@@ -79,7 +81,10 @@ var (
ErrNotImplement = errors.New("have not implement")
)
+// Params stores the Params
type Params map[string]interface{}
+
+// ParamsList stores paramslist
type ParamsList []interface{}
type orm struct {
@@ -188,7 +193,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
o.setPk(mi, ind, id)
- cnt += 1
+ cnt++
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
@@ -489,7 +494,7 @@ func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
-// create new orm
+// NewOrm create new orm
func NewOrm() Ormer {
BootStrap() // execute only once
@@ -501,7 +506,7 @@ func NewOrm() Ormer {
return o
}
-// create a new ormer object with specify *sql.DB for query
+// NewOrmWithDB create a new ormer object with specify *sql.DB for query
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
var al *alias
diff --git a/orm/orm_conds.go b/orm/orm_conds.go
index 6344653a..e56d6fbb 100644
--- a/orm/orm_conds.go
+++ b/orm/orm_conds.go
@@ -19,6 +19,7 @@ import (
"strings"
)
+// ExprSep define the expression separation
const (
ExprSep = "__"
)
@@ -32,19 +33,19 @@ type condValue struct {
isCond bool
}
-// condition struct.
+// Condition struct.
// work for WHERE conditions.
type Condition struct {
params []condValue
}
-// return new condition struct
+// NewCondition return new condition struct
func NewCondition() *Condition {
c := &Condition{}
return c
}
-// add expression to condition
+// And add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf(" args cannot empty"))
@@ -53,7 +54,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c
}
-// add NOT expression to condition
+// AndNot add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf(" args cannot empty"))
@@ -62,7 +63,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c
}
-// combine a condition to current condition
+// AndCond combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
@@ -74,7 +75,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c
}
-// add OR expression to condition
+// Or add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf(" args cannot empty"))
@@ -83,7 +84,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c
}
-// add OR NOT expression to condition
+// OrNot add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf(" args cannot empty"))
@@ -92,7 +93,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c
}
-// combine a OR condition to current condition
+// OrCond combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
@@ -104,12 +105,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c
}
-// check the condition arguments are empty or not.
+// IsEmpty check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool {
return len(c.params) == 0
}
-// clone a condition
+// clone clone a condition
func (c Condition) clone() *Condition {
return &c
}
diff --git a/orm/orm_log.go b/orm/orm_log.go
index 419d8e11..712eb219 100644
--- a/orm/orm_log.go
+++ b/orm/orm_log.go
@@ -23,11 +23,12 @@ import (
"time"
)
+// Log implement the log.Logger
type Log struct {
*log.Logger
}
-// set io.Writer to create a Logger.
+// NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log {
d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9)
@@ -41,7 +42,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if err != nil {
flag = "FAIL"
}
- con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query)
+ con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(formatDateTime), alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args))
for _, arg := range args {
cons = append(cons, fmt.Sprintf("%v", arg))
diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go
index 1eaccf72..b220bda6 100644
--- a/orm/orm_querym2m.go
+++ b/orm/orm_querym2m.go
@@ -14,9 +14,7 @@
package orm
-import (
- "reflect"
-)
+import "reflect"
// model to model struct
type queryM2M struct {
@@ -44,7 +42,21 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
dbase := orm.alias.DbBaser
var models []interface{}
+ var otherValues []interface{}
+ var otherNames []string
+ for _, colname := range mi.fields.dbcols {
+ if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column &&
+ mi.fields.columns[colname] != mi.fields.pk {
+ otherNames = append(otherNames, colname)
+ }
+ }
+ for i, md := range mds {
+ if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 {
+ otherValues = append(otherValues, md)
+ mds = append(mds[:i], mds[i+1:]...)
+ }
+ }
for _, md := range mds {
val := reflect.ValueOf(md)
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
@@ -67,11 +79,9 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
names := []string{mfi.column, rfi.column}
values := make([]interface{}, 0, len(models)*2)
-
for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md))
-
var v2 interface{}
if ind.Kind() != reflect.Struct {
v2 = ind.Interface()
@@ -81,11 +91,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
panic(ErrMissPK)
}
}
-
values = append(values, v1, v2)
}
-
+ names = append(names, otherNames...)
+ values = append(values, otherValues...)
return dbase.InsertValue(orm.db, mi, true, names, values)
}
diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go
index 4f5d5485..802a1fe0 100644
--- a/orm/orm_queryset.go
+++ b/orm/orm_queryset.go
@@ -25,11 +25,12 @@ type colValue struct {
type operator int
+// define Col operations
const (
- Col_Add operator = iota
- Col_Minus
- Col_Multiply
- Col_Except
+ ColAdd operator = iota
+ ColMinus
+ ColMultiply
+ ColExcept
)
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
@@ -38,7 +39,7 @@ const (
// }
func ColValue(opt operator, value interface{}) interface{} {
switch opt {
- case Col_Add, Col_Minus, Col_Multiply, Col_Except:
+ case ColAdd, ColMinus, ColMultiply, ColExcept:
default:
panic(fmt.Errorf("orm.ColValue wrong operator"))
}
@@ -60,7 +61,9 @@ type querySet struct {
relDepth int
limit int64
offset int64
+ groups []string
orders []string
+ distinct bool
orm *orm
}
@@ -105,6 +108,12 @@ func (o querySet) Offset(offset interface{}) QuerySeter {
return &o
}
+// add GROUP expression
+func (o querySet) GroupBy(exprs ...string) QuerySeter {
+ o.groups = exprs
+ return &o
+}
+
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter {
@@ -112,17 +121,22 @@ func (o querySet) OrderBy(exprs ...string) QuerySeter {
return &o
}
+// add DISTINCT to SELECT
+func (o querySet) Distinct() QuerySeter {
+ o.distinct = true
+ return &o
+}
+
// set relation model to query together.
// it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
- var related []string
if len(params) == 0 {
o.relDepth = DefaultRelsDepth
} else {
for _, p := range params {
switch val := p.(type) {
case string:
- related = append(o.related, val)
+ o.related = append(o.related, val)
case int:
o.relDepth = val
default:
@@ -130,7 +144,6 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
}
}
}
- o.related = related
return &o
}
diff --git a/orm/orm_raw.go b/orm/orm_raw.go
index 1393d414..cbb18064 100644
--- a/orm/orm_raw.go
+++ b/orm/orm_raw.go
@@ -165,14 +165,14 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
if str != "" {
if len(str) >= 19 {
str = str[:19]
- t, err := time.ParseInLocation(format_DateTime, str, o.orm.alias.TZ)
+ t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ)
if err == nil {
t = t.In(DefaultTimeLoc)
ind.Set(reflect.ValueOf(t))
}
} else if len(str) >= 10 {
str = str[:10]
- t, err := time.ParseInLocation(format_Date, str, DefaultTimeLoc)
+ t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc)
if err == nil {
ind.Set(reflect.ValueOf(t))
}
@@ -255,12 +255,13 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr
// query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error {
- refs := make([]interface{}, 0, len(containers))
- sInds := make([]reflect.Value, 0)
- eTyps := make([]reflect.Type, 0)
-
+ var (
+ refs = make([]interface{}, 0, len(containers))
+ sInds []reflect.Value
+ eTyps []reflect.Type
+ sMi *modelInfo
+ )
structMode := false
- var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@@ -385,12 +386,13 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
// query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
- refs := make([]interface{}, 0, len(containers))
- sInds := make([]reflect.Value, 0)
- eTyps := make([]reflect.Type, 0)
-
+ var (
+ refs = make([]interface{}, 0, len(containers))
+ sInds []reflect.Value
+ eTyps []reflect.Type
+ sMi *modelInfo
+ )
structMode := false
- var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
sInd := reflect.Indirect(val)
@@ -557,10 +559,9 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
- if r, err := o.orm.db.Query(query, args...); err != nil {
+ rs, err := o.orm.db.Query(query, args...)
+ if err != nil {
return 0, err
- } else {
- rs = r
}
defer rs.Close()
@@ -574,30 +575,30 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
for rs.Next() {
if cnt == 0 {
- if columns, err := rs.Columns(); err != nil {
+ columns, err := rs.Columns()
+ if err != nil {
return 0, err
+ }
+ if len(needCols) > 0 {
+ indexs = make([]int, 0, len(needCols))
} else {
+ indexs = make([]int, 0, len(columns))
+ }
+
+ cols = columns
+ refs = make([]interface{}, len(cols))
+ for i := range refs {
+ var ref sql.NullString
+ refs[i] = &ref
+
if len(needCols) > 0 {
- indexs = make([]int, 0, len(needCols))
- } else {
- indexs = make([]int, 0, len(columns))
- }
-
- cols = columns
- refs = make([]interface{}, len(cols))
- for i, _ := range refs {
- var ref sql.NullString
- refs[i] = &ref
-
- if len(needCols) > 0 {
- for _, c := range needCols {
- if c == cols[i] {
- indexs = append(indexs, i)
- }
+ for _, c := range needCols {
+ if c == cols[i] {
+ indexs = append(indexs, i)
}
- } else {
- indexs = append(indexs, i)
}
+ } else {
+ indexs = append(indexs, i)
}
}
}
@@ -684,11 +685,9 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
- var rs *sql.Rows
- if r, err := o.orm.db.Query(query, args...); err != nil {
+ rs, err := o.orm.db.Query(query, args...)
+ if err != nil {
return 0, err
- } else {
- rs = r
}
defer rs.Close()
@@ -706,32 +705,29 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
for rs.Next() {
if cnt == 0 {
- if columns, err := rs.Columns(); err != nil {
+ columns, err := rs.Columns()
+ if err != nil {
return 0, err
- } else {
- cols = columns
- refs = make([]interface{}, len(cols))
- for i, _ := range refs {
- if keyCol == cols[i] {
- keyIndex = i
- }
-
- if typ == 1 || keyIndex == i {
- var ref sql.NullString
- refs[i] = &ref
- } else {
- var ref interface{}
- refs[i] = &ref
- }
-
- if valueCol == cols[i] {
- valueIndex = i
- }
+ }
+ cols = columns
+ refs = make([]interface{}, len(cols))
+ for i := range refs {
+ if keyCol == cols[i] {
+ keyIndex = i
}
-
- if keyIndex == -1 || valueIndex == -1 {
- panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
+ if typ == 1 || keyIndex == i {
+ var ref sql.NullString
+ refs[i] = &ref
+ } else {
+ var ref interface{}
+ refs[i] = &ref
}
+ if valueCol == cols[i] {
+ valueIndex = i
+ }
+ }
+ if keyIndex == -1 || valueIndex == -1 {
+ panic(fmt.Errorf(" RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
}
}
diff --git a/orm/orm_test.go b/orm/orm_test.go
index e1c8e0f0..d6f6c7a9 100644
--- a/orm/orm_test.go
+++ b/orm/orm_test.go
@@ -31,13 +31,26 @@ import (
var _ = os.PathSeparator
var (
- test_Date = format_Date + " -0700"
- test_DateTime = format_DateTime + " -0700"
+ testDate = formatDate + " -0700"
+ testDateTime = formatDateTime + " -0700"
)
-func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok bool) {
+type argAny []interface{}
+
+// get interface by index from interface slice
+func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
+ if i >= 0 && i < len(a) {
+ r = a[i]
+ }
+ if len(args) > 0 {
+ r = args[0]
+ }
+ return
+}
+
+func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) {
if len(args) == 0 {
- return fmt.Errorf("miss args"), false
+ return false, fmt.Errorf("miss args")
}
b := args[0]
arg := argAny(args)
@@ -71,21 +84,21 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok b
wrongArg:
if err != nil {
- return err, false
+ return false, err
}
- return nil, true
+ return true, nil
}
func AssertIs(a interface{}, args ...interface{}) error {
- if err, ok := ValuesCompare(true, a, args...); ok == false {
+ if ok, err := ValuesCompare(true, a, args...); ok == false {
return err
}
return nil
}
func AssertNot(a interface{}, args ...interface{}) error {
- if err, ok := ValuesCompare(false, a, args...); ok == false {
+ if ok, err := ValuesCompare(false, a, args...); ok == false {
return err
}
return nil
@@ -171,8 +184,11 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Comment))
RegisterModel(new(UserBig))
RegisterModel(new(PostTags))
+ RegisterModel(new(Group))
+ RegisterModel(new(Permission))
+ RegisterModel(new(GroupPermissions))
- err := RunSyncdb("default", true, false)
+ err := RunSyncdb("default", true, Debug)
throwFail(t, err)
modelCache.clean()
@@ -187,6 +203,9 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Comment))
RegisterModel(new(UserBig))
RegisterModel(new(PostTags))
+ RegisterModel(new(Group))
+ RegisterModel(new(Permission))
+ RegisterModel(new(GroupPermissions))
BootStrap()
@@ -208,7 +227,7 @@ func TestModelSyntax(t *testing.T) {
}
}
-var Data_Values = map[string]interface{}{
+var DataValues = map[string]interface{}{
"Boolean": true,
"Char": "char",
"Text": "text",
@@ -235,7 +254,7 @@ func TestDataTypes(t *testing.T) {
d := Data{}
ind := reflect.Indirect(reflect.ValueOf(&d))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value))
}
@@ -244,22 +263,22 @@ func TestDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
- d = Data{Id: 1}
+ d = Data{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@@ -278,7 +297,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
- d = DataNull{Id: 1}
+ d = DataNull{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
@@ -309,7 +328,7 @@ func TestNullDataTypes(t *testing.T) {
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err)
- d = DataNull{Id: 2}
+ d = DataNull{ID: 2}
err = dORM.Read(&d)
throwFail(t, err)
@@ -362,7 +381,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 3))
- d = DataNull{Id: 3}
+ d = DataNull{ID: 3}
err = dORM.Read(&d)
throwFail(t, err)
@@ -402,7 +421,7 @@ func TestDataCustomTypes(t *testing.T) {
d := DataCustom{}
ind := reflect.Indirect(reflect.ValueOf(&d))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
@@ -414,13 +433,13 @@ func TestDataCustomTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
- d = DataCustom{Id: 1}
+ d = DataCustom{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
@@ -451,7 +470,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
- u := &User{Id: user.Id}
+ u := &User{ID: user.ID}
err = dORM.Read(u)
throwFail(t, err)
@@ -461,8 +480,8 @@ func TestCRUD(t *testing.T) {
throwFail(t, AssertIs(u.Status, 3))
throwFail(t, AssertIs(u.IsStaff, true))
throwFail(t, AssertIs(u.IsActive, true))
- throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), test_Date))
- throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), test_DateTime))
+ throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate))
+ throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime))
user.UserName = "astaxie"
user.Profile = profile
@@ -470,11 +489,11 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
- u = &User{Id: user.Id}
+ u = &User{ID: user.ID}
err = dORM.Read(u)
throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "astaxie"))
- throwFail(t, AssertIs(u.Profile.Id, profile.Id))
+ throwFail(t, AssertIs(u.Profile.ID, profile.ID))
u = &User{UserName: "astaxie", Password: "pass"}
err = dORM.Read(u, "UserName")
@@ -487,7 +506,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
- u = &User{Id: user.Id}
+ u = &User{ID: user.ID}
err = dORM.Read(u)
throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "QQ"))
@@ -497,7 +516,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
- u = &User{Id: user.Id}
+ u = &User{ID: user.ID}
err = dORM.Read(u)
throwFail(t, err)
throwFail(t, AssertIs(true, u.Profile == nil))
@@ -506,7 +525,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
- u = &User{Id: 100}
+ u = &User{ID: 100}
err = dORM.Read(u)
throwFail(t, AssertIs(err, ErrNoRows))
@@ -516,7 +535,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
- ub = UserBig{Id: 1}
+ ub = UserBig{ID: 1}
err = dORM.Read(&ub)
throwFail(t, err)
throwFail(t, AssertIs(ub.Name, "name"))
@@ -586,29 +605,29 @@ func TestInsertTestData(t *testing.T) {
throwFail(t, AssertIs(id, 4))
tags := []*Tag{
- &Tag{Name: "golang", BestPost: &Post{Id: 2}},
- &Tag{Name: "example"},
- &Tag{Name: "format"},
- &Tag{Name: "c++"},
+ {Name: "golang", BestPost: &Post{ID: 2}},
+ {Name: "example"},
+ {Name: "format"},
+ {Name: "c++"},
}
posts := []*Post{
- &Post{User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand.
+ {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand.
This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`},
- &Post{User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`},
- &Post{User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide.
+ {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`},
+ {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide.
With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`},
- &Post{User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code.
+ {User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code.
The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`},
}
comments := []*Comment{
- &Comment{Post: posts[0], Content: "a comment"},
- &Comment{Post: posts[1], Content: "yes"},
- &Comment{Post: posts[1]},
- &Comment{Post: posts[1]},
- &Comment{Post: posts[2]},
- &Comment{Post: posts[2]},
+ {Post: posts[0], Content: "a comment"},
+ {Post: posts[1], Content: "yes"},
+ {Post: posts[1]},
+ {Post: posts[1]},
+ {Post: posts[2]},
+ {Post: posts[2]},
}
for _, tag := range tags {
@@ -635,10 +654,47 @@ The program—and web server—godoc processes Go source files to extract docume
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
}
+
+ permissions := []*Permission{
+ {Name: "writePosts"},
+ {Name: "readComments"},
+ {Name: "readPosts"},
+ }
+
+ groups := []*Group{
+ {
+ Name: "admins",
+ Permissions: []*Permission{permissions[0], permissions[1], permissions[2]},
+ },
+ {
+ Name: "users",
+ Permissions: []*Permission{permissions[1], permissions[2]},
+ },
+ }
+
+ for _, permission := range permissions {
+ id, err := dORM.Insert(permission)
+ throwFail(t, err)
+ throwFail(t, AssertIs(id > 0, true))
+ }
+
+ for _, group := range groups {
+ _, err := dORM.Insert(group)
+ throwFail(t, err)
+ throwFail(t, AssertIs(id > 0, true))
+
+ num := len(group.Permissions)
+ if num > 0 {
+ nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions)
+ throwFailNow(t, err)
+ throwFailNow(t, AssertIs(nums, num))
+ }
+ }
+
}
func TestCustomField(t *testing.T) {
- user := User{Id: 2}
+ user := User{ID: 2}
err := dORM.Read(&user)
throwFailNow(t, err)
@@ -648,7 +704,7 @@ func TestCustomField(t *testing.T) {
_, err = dORM.Update(&user, "Langs", "Extra")
throwFailNow(t, err)
- user = User{Id: 2}
+ user = User{ID: 2}
err = dORM.Read(&user)
throwFailNow(t, err)
throwFailNow(t, AssertIs(len(user.Langs), 2))
@@ -702,7 +758,7 @@ func TestOperators(t *testing.T) {
var shouldNum int
- if IsSqlite {
+ if IsSqlite || IsTidb {
shouldNum = 2
} else {
shouldNum = 0
@@ -740,7 +796,7 @@ func TestOperators(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
- if IsSqlite {
+ if IsSqlite || IsTidb {
shouldNum = 1
} else {
shouldNum = 0
@@ -758,7 +814,7 @@ func TestOperators(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 2))
- if IsSqlite {
+ if IsSqlite || IsTidb {
shouldNum = 2
} else {
shouldNum = 0
@@ -889,9 +945,9 @@ func TestAll(t *testing.T) {
throwFailNow(t, AssertIs(users2[0].UserName, "slene"))
throwFailNow(t, AssertIs(users2[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(users2[2].UserName, "nobody"))
- throwFailNow(t, AssertIs(users2[0].Id, 0))
- throwFailNow(t, AssertIs(users2[1].Id, 0))
- throwFailNow(t, AssertIs(users2[2].Id, 0))
+ throwFailNow(t, AssertIs(users2[0].ID, 0))
+ throwFailNow(t, AssertIs(users2[1].ID, 0))
+ throwFailNow(t, AssertIs(users2[2].ID, 0))
throwFailNow(t, AssertIs(users2[0].Profile == nil, false))
throwFailNow(t, AssertIs(users2[1].Profile == nil, false))
throwFailNow(t, AssertIs(users2[2].Profile == nil, true))
@@ -986,6 +1042,10 @@ func TestValuesFlat(t *testing.T) {
}
func TestRelatedSel(t *testing.T) {
+ if IsTidb {
+ // Skip it. TiDB does not support relation now.
+ return
+ }
qs := dORM.QueryTable("user")
num, err := qs.Filter("profile__age", 28).Count()
throwFail(t, err)
@@ -1112,7 +1172,7 @@ func TestReverseQuery(t *testing.T) {
func TestLoadRelated(t *testing.T) {
// load reverse foreign key
- user := User{Id: 3}
+ user := User{ID: 3}
err := dORM.Read(&user)
throwFailNow(t, err)
@@ -1121,7 +1181,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2))
- throwFailNow(t, AssertIs(user.Posts[0].User.Id, 3))
+ throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3))
num, err = dORM.LoadRelated(&user, "Posts", true)
throwFailNow(t, err)
@@ -1143,8 +1203,8 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
// load reverse one to one
- profile := Profile{Id: 3}
- profile.BestPost = &Post{Id: 2}
+ profile := Profile{ID: 3}
+ profile.BestPost = &Post{ID: 2}
num, err = dORM.Update(&profile, "BestPost")
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
@@ -1183,7 +1243,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false))
throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples"))
- post := Post{Id: 2}
+ post := Post{ID: 2}
// load rel foreign key
err = dORM.Read(&post)
@@ -1204,7 +1264,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.User.Profile.Age, 30))
// load rel m2m
- post = Post{Id: 2}
+ post = Post{ID: 2}
err = dORM.Read(&post)
throwFailNow(t, err)
@@ -1224,7 +1284,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie"))
// load reverse m2m
- tag := Tag{Id: 1}
+ tag := Tag{ID: 1}
err = dORM.Read(&tag)
throwFailNow(t, err)
@@ -1233,22 +1293,22 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
- throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2))
+ throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true))
num, err = dORM.LoadRelated(&tag, "Posts", true)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
- throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2))
+ throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene"))
}
func TestQueryM2M(t *testing.T) {
- post := Post{Id: 4}
+ post := Post{ID: 4}
m2m := dORM.QueryM2M(&post, "Tags")
- tag1 := []*Tag{&Tag{Name: "TestTag1"}, &Tag{Name: "TestTag2"}}
+ tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}}
tag2 := &Tag{Name: "TestTag3"}
tag3 := []interface{}{&Tag{Name: "TestTag4"}}
@@ -1311,7 +1371,7 @@ func TestQueryM2M(t *testing.T) {
m2m = dORM.QueryM2M(&tag, "Posts")
- post1 := []*Post{&Post{Title: "TestPost1"}, &Post{Title: "TestPost2"}}
+ post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}}
post2 := &Post{Title: "TestPost3"}
post3 := []interface{}{&Post{Title: "TestPost4"}}
@@ -1319,7 +1379,7 @@ func TestQueryM2M(t *testing.T) {
for _, post := range posts {
p := post.(*Post)
- p.User = &User{Id: 1}
+ p.User = &User{ID: 1}
_, err := dORM.Insert(post)
throwFailNow(t, err)
}
@@ -1394,6 +1454,18 @@ func TestQueryRelate(t *testing.T) {
// throwFailNow(t, AssertIs(num, 2))
}
+func TestPkManyRelated(t *testing.T) {
+ permission := &Permission{Name: "readPosts"}
+ err := dORM.Read(permission, "Name")
+ throwFailNow(t, err)
+
+ var groups []*Group
+ qs := dORM.QueryTable("Group")
+ num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups)
+ throwFailNow(t, err)
+ throwFailNow(t, AssertIs(num, 2))
+}
+
func TestPrepareInsert(t *testing.T) {
qs := dORM.QueryTable("user")
i, err := qs.PrepareInsert()
@@ -1459,10 +1531,10 @@ func TestRawQueryRow(t *testing.T) {
Decimal float64
)
- data_values := make(map[string]interface{}, len(Data_Values))
+ dataValues := make(map[string]interface{}, len(DataValues))
- for k, v := range Data_Values {
- data_values[strings.ToLower(k)] = v
+ for k, v := range DataValues {
+ dataValues[strings.ToLower(k)] = v
}
Q := dDbBaser.TableQuote()
@@ -1488,14 +1560,14 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(id, 1))
case "date":
v = v.(time.Time).In(DefaultTimeLoc)
- value := data_values[col].(time.Time).In(DefaultTimeLoc)
- throwFail(t, AssertIs(v, value, test_Date))
+ value := dataValues[col].(time.Time).In(DefaultTimeLoc)
+ throwFail(t, AssertIs(v, value, testDate))
case "datetime":
v = v.(time.Time).In(DefaultTimeLoc)
- value := data_values[col].(time.Time).In(DefaultTimeLoc)
- throwFail(t, AssertIs(v, value, test_DateTime))
+ value := dataValues[col].(time.Time).In(DefaultTimeLoc)
+ throwFail(t, AssertIs(v, value, testDateTime))
default:
- throwFail(t, AssertIs(v, data_values[col]))
+ throwFail(t, AssertIs(v, dataValues[col]))
}
}
@@ -1529,16 +1601,16 @@ func TestQueryRows(t *testing.T) {
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@@ -1553,16 +1625,16 @@ func TestQueryRows(t *testing.T) {
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
- for name, value := range Data_Values {
+ for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
- vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
- value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
+ vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
+ value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@@ -1699,25 +1771,25 @@ func TestUpdate(t *testing.T) {
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
- "Nums": ColValue(Col_Add, 100),
+ "Nums": ColValue(ColAdd, 100),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
- "Nums": ColValue(Col_Minus, 50),
+ "Nums": ColValue(ColMinus, 50),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
- "Nums": ColValue(Col_Multiply, 3),
+ "Nums": ColValue(ColMultiply, 3),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
- "Nums": ColValue(Col_Except, 5),
+ "Nums": ColValue(ColExcept, 5),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
@@ -1838,15 +1910,15 @@ func TestReadOrCreate(t *testing.T) {
throwFail(t, AssertIs(u.Status, 7))
throwFail(t, AssertIs(u.IsStaff, false))
throwFail(t, AssertIs(u.IsActive, true))
- throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date))
- throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime))
+ throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate))
+ throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime))
nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"}
created, pk, err = dORM.ReadOrCreate(nu, "UserName")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
- throwFail(t, AssertIs(nu.Id, u.Id))
- throwFail(t, AssertIs(pk, u.Id))
+ throwFail(t, AssertIs(nu.ID, u.ID))
+ throwFail(t, AssertIs(pk, u.ID))
throwFail(t, AssertIs(nu.UserName, u.UserName))
throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above
throwFail(t, AssertIs(nu.Password, u.Password))
diff --git a/orm/qb.go b/orm/qb.go
index efe368db..9f778916 100644
--- a/orm/qb.go
+++ b/orm/qb.go
@@ -16,6 +16,7 @@ package orm
import "errors"
+// QueryBuilder is the Query builder interface
type QueryBuilder interface {
Select(fields ...string) QueryBuilder
From(tables ...string) QueryBuilder
@@ -43,15 +44,18 @@ type QueryBuilder interface {
String() string
}
+// NewQueryBuilder return the QueryBuilder
func NewQueryBuilder(driver string) (qb QueryBuilder, err error) {
if driver == "mysql" {
qb = new(MySQLQueryBuilder)
+ } else if driver == "tidb" {
+ qb = new(TiDBQueryBuilder)
} else if driver == "postgres" {
- err = errors.New("postgres query builder is not supported yet!")
+ err = errors.New("postgres query builder is not supported yet")
} else if driver == "sqlite" {
- err = errors.New("sqlite query builder is not supported yet!")
+ err = errors.New("sqlite query builder is not supported yet")
} else {
- err = errors.New("unknown driver for query builder!")
+ err = errors.New("unknown driver for query builder")
}
return
}
diff --git a/orm/qb_mysql.go b/orm/qb_mysql.go
index 9ce9b7d9..886bc50e 100644
--- a/orm/qb_mysql.go
+++ b/orm/qb_mysql.go
@@ -20,134 +20,160 @@ import (
"strings"
)
-const COMMA_SPACE = ", "
+// CommaSpace is the separation
+const CommaSpace = ", "
+// MySQLQueryBuilder is the SQL build
type MySQLQueryBuilder struct {
Tokens []string
}
+// Select will join the fields
func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb
}
+// From join the tables
func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb
}
+// InnerJoin INNER JOIN the table
func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb
}
+// LeftJoin LEFT JOIN the table
func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb
}
+// RightJoin RIGHT JOIN the table
func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb
}
+// On join with on cond
func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond)
return qb
}
+// Where join the Where cond
func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb
}
+// And join the and cond
func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond)
return qb
}
+// Or join the or cond
func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond)
return qb
}
+// In join the IN (vals)
func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, COMMA_SPACE), ")")
+ qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb
}
+// OrderBy join the Order by fields
func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb
}
+// Asc join the asc
func (qb *MySQLQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC")
return qb
}
+// Desc join the desc
func (qb *MySQLQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC")
return qb
}
+// Limit join the limit num
func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb
}
+// Offset join the offset num
func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb
}
+// GroupBy join the Group by fields
func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb
}
+// Having join the Having cond
func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb
}
+// Update join the update table
func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb
}
+// Set join the set kv
func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder {
- qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb
}
+// Delete join the Delete tables
func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 {
- qb.Tokens = append(qb.Tokens, strings.Join(tables, COMMA_SPACE))
+ qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
}
return qb
}
+// InsertInto join the insert SQL
func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 {
- fieldsStr := strings.Join(fields, COMMA_SPACE)
+ fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
}
return qb
}
+// Values join the Values(vals)
func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder {
- valsStr := strings.Join(vals, COMMA_SPACE)
+ valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb
}
+// Subquery join the sub as alias
func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias)
}
+// String join all Tokens
func (qb *MySQLQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ")
}
diff --git a/orm/qb_tidb.go b/orm/qb_tidb.go
new file mode 100644
index 00000000..c504049e
--- /dev/null
+++ b/orm/qb_tidb.go
@@ -0,0 +1,176 @@
+// Copyright 2015 TiDB Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package orm
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+// TiDBQueryBuilder is the SQL build
+type TiDBQueryBuilder struct {
+ Tokens []string
+}
+
+// Select will join the fields
+func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
+ return qb
+}
+
+// From join the tables
+func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
+ return qb
+}
+
+// InnerJoin INNER JOIN the table
+func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
+ return qb
+}
+
+// LeftJoin LEFT JOIN the table
+func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
+ return qb
+}
+
+// RightJoin RIGHT JOIN the table
+func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
+ return qb
+}
+
+// On join with on cond
+func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "ON", cond)
+ return qb
+}
+
+// Where join the Where cond
+func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "WHERE", cond)
+ return qb
+}
+
+// And join the and cond
+func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "AND", cond)
+ return qb
+}
+
+// Or join the or cond
+func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "OR", cond)
+ return qb
+}
+
+// In join the IN (vals)
+func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
+ return qb
+}
+
+// OrderBy join the Order by fields
+func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
+ return qb
+}
+
+// Asc join the asc
+func (qb *TiDBQueryBuilder) Asc() QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "ASC")
+ return qb
+}
+
+// Desc join the desc
+func (qb *TiDBQueryBuilder) Desc() QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "DESC")
+ return qb
+}
+
+// Limit join the limit num
+func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
+ return qb
+}
+
+// Offset join the offset num
+func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
+ return qb
+}
+
+// GroupBy join the Group by fields
+func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
+ return qb
+}
+
+// Having join the Having cond
+func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "HAVING", cond)
+ return qb
+}
+
+// Update join the update table
+func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
+ return qb
+}
+
+// Set join the set kv
+func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
+ return qb
+}
+
+// Delete join the Delete tables
+func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "DELETE")
+ if len(tables) != 0 {
+ qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
+ }
+ return qb
+}
+
+// InsertInto join the insert SQL
+func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
+ if len(fields) != 0 {
+ fieldsStr := strings.Join(fields, CommaSpace)
+ qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
+ }
+ return qb
+}
+
+// Values join the Values(vals)
+func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder {
+ valsStr := strings.Join(vals, CommaSpace)
+ qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
+ return qb
+}
+
+// Subquery join the sub as alias
+func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string {
+ return fmt.Sprintf("(%s) AS %s", sub, alias)
+}
+
+// String join all Tokens
+func (qb *TiDBQueryBuilder) String() string {
+ return strings.Join(qb.Tokens, " ")
+}
diff --git a/orm/types.go b/orm/types.go
index b46be4fc..5fac5fed 100644
--- a/orm/types.go
+++ b/orm/types.go
@@ -20,13 +20,13 @@ import (
"time"
)
-// database driver
+// Driver define database driver
type Driver interface {
Name() string
Type() DriverType
}
-// field info
+// Fielder define field info
type Fielder interface {
String() string
FieldType() int
@@ -34,84 +34,315 @@ type Fielder interface {
RawValue() interface{}
}
-// orm struct
+// Ormer define the orm interface
type Ormer interface {
- Read(interface{}, ...string) error
- ReadOrCreate(interface{}, string, ...string) (bool, int64, error)
+ // read data to model
+ // for example:
+ // this will find User by Id field
+ // u = &User{Id: user.Id}
+ // err = Ormer.Read(u)
+ // this will find User by UserName field
+ // u = &User{UserName: "astaxie", Password: "pass"}
+ // err = Ormer.Read(u, "UserName")
+ Read(md interface{}, cols ...string) error
+ // Try to read a row from the database, or insert one if it doesn't exist
+ ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error)
+ // insert model data to database
+ // for example:
+ // user := new(User)
+ // id, err = Ormer.Insert(user)
+ // user must a pointer and Insert will set user's pk field
Insert(interface{}) (int64, error)
- InsertMulti(int, interface{}) (int64, error)
- Update(interface{}, ...string) (int64, error)
- Delete(interface{}) (int64, error)
- LoadRelated(interface{}, string, ...interface{}) (int64, error)
- QueryM2M(interface{}, string) QueryM2Mer
- QueryTable(interface{}) QuerySeter
- Using(string) error
+ // insert some models to database
+ InsertMulti(bulk int, mds interface{}) (int64, error)
+ // update model to database.
+ // cols set the columns those want to update.
+ // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns
+ // for example:
+ // user := User{Id: 2}
+ // user.Langs = append(user.Langs, "zh-CN", "en-US")
+ // user.Extra.Name = "beego"
+ // user.Extra.Data = "orm"
+ // num, err = Ormer.Update(&user, "Langs", "Extra")
+ Update(md interface{}, cols ...string) (int64, error)
+ // delete model in database
+ Delete(md interface{}) (int64, error)
+ // load related models to md model.
+ // args are limit, offset int and order string.
+ //
+ // example:
+ // Ormer.LoadRelated(post,"Tags")
+ // for _,tag := range post.Tags{...}
+ //args[0] bool true useDefaultRelsDepth ; false depth 0
+ //args[0] int loadRelationDepth
+ //args[1] int limit default limit 1000
+ //args[2] int offset default offset 0
+ //args[3] string order for example : "-Id"
+ // make sure the relation is defined in model struct tags.
+ LoadRelated(md interface{}, name string, args ...interface{}) (int64, error)
+ // create a models to models queryer
+ // for example:
+ // post := Post{Id: 4}
+ // m2m := Ormer.QueryM2M(&post, "Tags")
+ QueryM2M(md interface{}, name string) QueryM2Mer
+ // return a QuerySeter for table operations.
+ // table name can be string or struct.
+ // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
+ QueryTable(ptrStructOrTableName interface{}) QuerySeter
+ // switch to another registered database driver by given name.
+ Using(name string) error
+ // begin transaction
+ // for example:
+ // o := NewOrm()
+ // err := o.Begin()
+ // ...
+ // err = o.Rollback()
Begin() error
+ // commit transaction
Commit() error
+ // rollback transaction
Rollback() error
- Raw(string, ...interface{}) RawSeter
+ // return a raw query seter for raw sql string.
+ // for example:
+ // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec()
+ // // update user testing's name to slene
+ Raw(query string, args ...interface{}) RawSeter
Driver() Driver
}
-// insert prepared statement
+// Inserter insert prepared statement
type Inserter interface {
Insert(interface{}) (int64, error)
Close() error
}
-// query seter
+// QuerySeter query seter
type QuerySeter interface {
+ // add condition expression to QuerySeter.
+ // for example:
+ // filter by UserName == 'slene'
+ // qs.Filter("UserName", "slene")
+ // sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28
+ // Filter("profile__Age", 28)
+ // // time compare
+ // qs.Filter("created", time.Now())
Filter(string, ...interface{}) QuerySeter
+ // add NOT condition to querySeter.
+ // have the same usage as Filter
Exclude(string, ...interface{}) QuerySeter
+ // set condition to QuerySeter.
+ // sql's where condition
+ // cond := orm.NewCondition()
+ // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
+ // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
+ // num, err := qs.SetCond(cond1).Count()
SetCond(*Condition) QuerySeter
- Limit(interface{}, ...interface{}) QuerySeter
- Offset(interface{}) QuerySeter
- OrderBy(...string) QuerySeter
- RelatedSel(...interface{}) QuerySeter
+ // add LIMIT value.
+ // args[0] means offset, e.g. LIMIT num,offset.
+ // if Limit <= 0 then Limit will be set to default limit ,eg 1000
+ // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000
+ // for example:
+ // qs.Limit(10, 2)
+ // // sql-> limit 10 offset 2
+ Limit(limit interface{}, args ...interface{}) QuerySeter
+ // add OFFSET value
+ // same as Limit function's args[0]
+ Offset(offset interface{}) QuerySeter
+ // add ORDER expression.
+ // "column" means ASC, "-column" means DESC.
+ // for example:
+ // qs.OrderBy("-status")
+ OrderBy(exprs ...string) QuerySeter
+ // set relation model to query together.
+ // it will query relation models and assign to parent model.
+ // for example:
+ // // will load all related fields use left join .
+ // qs.RelatedSel().One(&user)
+ // // will load related field only profile
+ // qs.RelatedSel("profile").One(&user)
+ // user.Profile.Age = 32
+ RelatedSel(params ...interface{}) QuerySeter
+ // return QuerySeter execution result number
+ // for example:
+ // num, err = qs.Filter("profile__age__gt", 28).Count()
Count() (int64, error)
+ // check result empty or not after QuerySeter executed
+ // the same as QuerySeter.Count > 0
Exist() bool
- Update(Params) (int64, error)
+ // execute update with parameters
+ // for example:
+ // num, err = qs.Filter("user_name", "slene").Update(Params{
+ // "Nums": ColValue(Col_Minus, 50),
+ // }) // user slene's Nums will minus 50
+ // num, err = qs.Filter("UserName", "slene").Update(Params{
+ // "user_name": "slene2"
+ // }) // user slene's name will change to slene2
+ Update(values Params) (int64, error)
+ // delete from table
+ //for example:
+ // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
+ // //delete two user who's name is testing1 or testing2
Delete() (int64, error)
+ // return a insert queryer.
+ // it can be used in times.
+ // example:
+ // i,err := sq.PrepareInsert()
+ // num, err = i.Insert(&user1) // user table will add one record user1 at once
+ // num, err = i.Insert(&user2) // user table will add one record user2 at once
+ // err = i.Close() //don't forget call Close
PrepareInsert() (Inserter, error)
- All(interface{}, ...string) (int64, error)
- One(interface{}, ...string) error
- Values(*[]Params, ...string) (int64, error)
- ValuesList(*[]ParamsList, ...string) (int64, error)
- ValuesFlat(*ParamsList, string) (int64, error)
- RowsToMap(*Params, string, string) (int64, error)
- RowsToStruct(interface{}, string, string) (int64, error)
+ // query all data and map to containers.
+ // cols means the columns when querying.
+ // for example:
+ // var users []*User
+ // qs.All(&users) // users[0],users[1],users[2] ...
+ All(container interface{}, cols ...string) (int64, error)
+ // query one row data and map to containers.
+ // cols means the columns when querying.
+ // for example:
+ // var user User
+ // qs.One(&user) //user.UserName == "slene"
+ One(container interface{}, cols ...string) error
+ // query all data and map to []map[string]interface.
+ // expres means condition expression.
+ // it converts data to []map[column]value.
+ // for example:
+ // var maps []Params
+ // qs.Values(&maps) //maps[0]["UserName"]=="slene"
+ Values(results *[]Params, exprs ...string) (int64, error)
+ // query all data and map to [][]interface
+ // it converts data to [][column_index]value
+ // for example:
+ // var list []ParamsList
+ // qs.ValuesList(&list) // list[0][1] == "slene"
+ ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
+ // query all data and map to []interface.
+ // it's designed for one column record set, auto change to []value, not [][column]value.
+ // for example:
+ // var list ParamsList
+ // qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
+ ValuesFlat(result *ParamsList, expr string) (int64, error)
+ // query all rows into map[string]interface with specify key and value column name.
+ // keyCol = "name", valueCol = "value"
+ // table data
+ // name | value
+ // total | 100
+ // found | 200
+ // to map[string]interface{}{
+ // "total": 100,
+ // "found": 200,
+ // }
+ RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
+ // query all rows into struct with specify key and value column name.
+ // keyCol = "name", valueCol = "value"
+ // table data
+ // name | value
+ // total | 100
+ // found | 200
+ // to struct {
+ // Total int
+ // Found int
+ // }
+ RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
}
-// model to model query struct
+// QueryM2Mer model to model query struct
+// all operations are on the m2m table only, will not affect the origin model table
type QueryM2Mer interface {
+ // add models to origin models when creating queryM2M.
+ // example:
+ // m2m := orm.QueryM2M(post,"Tag")
+ // m2m.Add(&Tag1{},&Tag2{})
+ // for _,tag := range post.Tags{}{ ... }
+ // param could also be any of the follow
+ // []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}}
+ // &Tag{Id:5,Name: "TestTag3"}
+ // []interface{}{&Tag{Id:6,Name: "TestTag4"}}
+ // insert one or more rows to m2m table
+ // make sure the relation is defined in post model struct tag.
Add(...interface{}) (int64, error)
+ // remove models following the origin model relationship
+ // only delete rows from m2m table
+ // for example:
+ //tag3 := &Tag{Id:5,Name: "TestTag3"}
+ //num, err = m2m.Remove(tag3)
Remove(...interface{}) (int64, error)
+ // check model is existed in relationship of origin model
Exist(interface{}) bool
+ // clean all models in related of origin model
Clear() (int64, error)
+ // count all related models of origin model
Count() (int64, error)
}
-// raw query statement
+// RawPreparer raw query statement
type RawPreparer interface {
Exec(...interface{}) (sql.Result, error)
Close() error
}
-// raw query seter
+// RawSeter raw query seter
+// create From Ormer.Raw
+// for example:
+// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q)
+// rs := Ormer.Raw(sql, 1)
type RawSeter interface {
+ //execute sql and get result
Exec() (sql.Result, error)
- QueryRow(...interface{}) error
- QueryRows(...interface{}) (int64, error)
+ //query data and map to container
+ //for example:
+ // var name string
+ // var id int
+ // rs.QueryRow(&id,&name) // id==2 name=="slene"
+ QueryRow(containers ...interface{}) error
+
+ // query data rows and map to container
+ // var ids []int
+ // var names []int
+ // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q)
+ // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"}
+ QueryRows(containers ...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
- Values(*[]Params, ...string) (int64, error)
- ValuesList(*[]ParamsList, ...string) (int64, error)
- ValuesFlat(*ParamsList, ...string) (int64, error)
- RowsToMap(*Params, string, string) (int64, error)
- RowsToStruct(interface{}, string, string) (int64, error)
+ // query data to []map[string]interface
+ // see QuerySeter's Values
+ Values(container *[]Params, cols ...string) (int64, error)
+ // query data to [][]interface
+ // see QuerySeter's ValuesList
+ ValuesList(container *[]ParamsList, cols ...string) (int64, error)
+ // query data to []interface
+ // see QuerySeter's ValuesFlat
+ ValuesFlat(container *ParamsList, cols ...string) (int64, error)
+ // query all rows into map[string]interface with specify key and value column name.
+ // keyCol = "name", valueCol = "value"
+ // table data
+ // name | value
+ // total | 100
+ // found | 200
+ // to map[string]interface{}{
+ // "total": 100,
+ // "found": 200,
+ // }
+ RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
+ // query all rows into struct with specify key and value column name.
+ // keyCol = "name", valueCol = "value"
+ // table data
+ // name | value
+ // total | 100
+ // found | 200
+ // to struct {
+ // Total int
+ // Found int
+ // }
+ RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
+
+ // return prepared raw statement for used in times.
+ // for example:
+ // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare()
+ // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`)
Prepare() (RawPreparer, error)
}
-// statement querier
+// stmtQuerier statement querier
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
@@ -160,8 +391,8 @@ type dbBaser interface {
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
- OperatorSql(string) string
- GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
+ OperatorSQL(string) string
+ GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)
diff --git a/orm/utils.go b/orm/utils.go
index df6147c1..99437c7b 100644
--- a/orm/utils.go
+++ b/orm/utils.go
@@ -22,9 +22,10 @@ import (
"time"
)
+// StrTo is the target string
type StrTo string
-// set string
+// Set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
@@ -33,93 +34,93 @@ func (f *StrTo) Set(v string) {
}
}
-// clean string
+// Clear string
func (f *StrTo) Clear() {
*f = StrTo(0x1E)
}
-// check string exist
+// Exist check string exist
func (f StrTo) Exist() bool {
return string(f) != string(0x1E)
}
-// string to bool
+// Bool string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
-// string to float32
+// Float32 string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
-// string to float64
+// Float64 string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
-// string to int
+// Int string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
-// string to int8
+// Int8 string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
-// string to int16
+// Int16 string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
-// string to int32
+// Int32 string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
-// string to int64
+// Int64 string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err
}
-// string to uint
+// Uint string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
-// string to uint8
+// Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
-// string to uint16
+// Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
-// string to uint31
+// Uint32 string to uint31
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
-// string to uint64
+// Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err
}
-// string to string
+// String string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
@@ -127,7 +128,7 @@ func (f StrTo) String() string {
return ""
}
-// interface to string
+// ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
@@ -166,7 +167,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s
}
-// interface to int64
+// ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {
@@ -195,7 +196,7 @@ func snakeString(s string) string {
}
data = append(data, d)
}
- return strings.ToLower(string(data[:len(data)]))
+ return strings.ToLower(string(data[:]))
}
// camel string, xx_yy to XxYy
@@ -220,7 +221,7 @@ func camelString(s string) string {
}
data = append(data, d)
}
- return string(data[:len(data)])
+ return string(data[:])
}
type argString []string
@@ -248,30 +249,12 @@ func (a argInt) Get(i int, args ...int) (r int) {
return
}
-type argAny []interface{}
-
-// get interface by index from interface slice
-func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
- if i >= 0 && i < len(a) {
- r = a[i]
- }
- if len(args) > 0 {
- r = args[0]
- }
- return
-}
-
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
-// format time string
-func timeFormat(t time.Time, format string) string {
- return t.Format(format)
-}
-
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() {
diff --git a/parser.go b/parser.go
index be91b4cb..b14d74b9 100644
--- a/parser.go
+++ b/parser.go
@@ -24,6 +24,7 @@ import (
"io/ioutil"
"os"
"path"
+ "sort"
"strings"
"github.com/astaxie/beego/utils"
@@ -36,18 +37,18 @@ import (
)
func init() {
- {{.globalinfo}}
+{{.globalinfo}}
}
`
var (
- lastupdateFilename string = "lastupdate.tmp"
+ lastupdateFilename = "lastupdate.tmp"
commentFilename string
pkgLastupdate map[string]int64
genInfoList map[string][]ControllerComments
)
-const COMMENTFL = "commentsRouter_"
+const coomentPrefix = "commentsRouter_"
func init() {
pkgLastupdate = make(map[string]int64)
@@ -55,7 +56,7 @@ func init() {
func parserPkg(pkgRealpath, pkgpath string) error {
rep := strings.NewReplacer("/", "_", ".", "_")
- commentFilename = COMMENTFL + rep.Replace(pkgpath) + ".go"
+ commentFilename = coomentPrefix + rep.Replace(pkgpath) + ".go"
if !compareFile(pkgRealpath) {
Info(pkgRealpath + " has not changed, not reloading")
return nil
@@ -76,7 +77,10 @@ func parserPkg(pkgRealpath, pkgpath string) error {
switch specDecl := d.(type) {
case *ast.FuncDecl:
if specDecl.Recv != nil {
- parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(specDecl.Recv.List[0].Type.(*ast.StarExpr).X), pkgpath)
+ exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser
+ if ok {
+ parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath)
+ }
}
}
}
@@ -126,10 +130,18 @@ func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpat
}
func genRouterCode() {
- os.Mkdir(path.Join(workPath, "routers"), 0755)
+ os.Mkdir("routers", 0755)
Info("generate router from comments")
- var globalinfo string
- for k, cList := range genInfoList {
+ var (
+ globalinfo string
+ sortKey []string
+ )
+ for k := range genInfoList {
+ sortKey = append(sortKey, k)
+ }
+ sort.Strings(sortKey)
+ for _, k := range sortKey {
+ cList := genInfoList[k]
for _, c := range cList {
allmethod := "nil"
if len(c.AllowHTTPMethods) > 0 {
@@ -160,7 +172,7 @@ func genRouterCode() {
}
}
if globalinfo != "" {
- f, err := os.Create(path.Join(workPath, "routers", commentFilename))
+ f, err := os.Create(path.Join("routers", commentFilename))
if err != nil {
panic(err)
}
@@ -170,11 +182,11 @@ func genRouterCode() {
}
func compareFile(pkgRealpath string) bool {
- if !utils.FileExists(path.Join(workPath, "routers", commentFilename)) {
+ if !utils.FileExists(path.Join("routers", commentFilename)) {
return true
}
- if utils.FileExists(path.Join(workPath, lastupdateFilename)) {
- content, err := ioutil.ReadFile(path.Join(workPath, lastupdateFilename))
+ if utils.FileExists(lastupdateFilename) {
+ content, err := ioutil.ReadFile(lastupdateFilename)
if err != nil {
return true
}
@@ -202,7 +214,7 @@ func savetoFile(pkgRealpath string) {
if err != nil {
return
}
- ioutil.WriteFile(path.Join(workPath, lastupdateFilename), d, os.ModePerm)
+ ioutil.WriteFile(lastupdateFilename, d, os.ModePerm)
}
func getpathTime(pkgRealpath string) (lastupdate int64, err error) {
diff --git a/plugins/apiauth/apiauth.go b/plugins/apiauth/apiauth.go
index bbae7def..8af08088 100644
--- a/plugins/apiauth/apiauth.go
+++ b/plugins/apiauth/apiauth.go
@@ -33,7 +33,7 @@
// // maybe store in configure, maybe in database
// }
//
-// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIAuthWithFunc(getAppSecret, 360))
+// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360))
//
// Infomation:
//
@@ -41,7 +41,7 @@
//
// 1. appid
//
-// appid is asigned to the application
+// appid is assigned to the application
//
// 2. signature
//
@@ -68,8 +68,10 @@ import (
"github.com/astaxie/beego/context"
)
-type AppIdToAppSecret func(string) string
+// AppIDToAppSecret is used to get appsecret throw appid
+type AppIDToAppSecret func(string) string
+// APIBaiscAuth use the basic appid/appkey as the AppIdToAppSecret
func APIBaiscAuth(appid, appkey string) beego.FilterFunc {
ft := func(aid string) string {
if aid == appid {
@@ -77,52 +79,54 @@ func APIBaiscAuth(appid, appkey string) beego.FilterFunc {
}
return ""
}
- return APIAuthWithFunc(ft, 300)
+ return APISecretAuth(ft, 300)
}
-func APIAuthWithFunc(f AppIdToAppSecret, timeout int) beego.FilterFunc {
+// APISecretAuth use AppIdToAppSecret verify and
+func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc {
return func(ctx *context.Context) {
if ctx.Input.Query("appid") == "" {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("miss query param: appid")
return
}
appsecret := f(ctx.Input.Query("appid"))
if appsecret == "" {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("not exist this appid")
return
}
if ctx.Input.Query("signature") == "" {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("miss query param: signature")
return
}
if ctx.Input.Query("timestamp") == "" {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("miss query param: timestamp")
return
}
u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp"))
if err != nil {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05")
return
}
t := time.Now()
if t.Sub(u).Seconds() > float64(timeout) {
- ctx.Output.SetStatus(403)
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("timeout! the request time is long ago, please try again")
return
}
if ctx.Input.Query("signature") !=
- Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.Uri()) {
- ctx.Output.SetStatus(403)
+ Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URI()) {
+ ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("auth failed")
}
}
}
+// Signature used to generate signature with the appsecret/method/params/RequestURI
func Signature(appsecret, method string, params url.Values, RequestURI string) (result string) {
var query string
pa := make(map[string]string)
@@ -139,11 +143,11 @@ func Signature(appsecret, method string, params url.Values, RequestURI string) (
query = fmt.Sprintf("%v%v%v", query, vs.Keys[i], vs.Vals[i])
}
}
- string_to_sign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURI)
+ stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURI)
sha256 := sha256.New
hash := hmac.New(sha256, []byte(appsecret))
- hash.Write([]byte(string_to_sign))
+ hash.Write([]byte(stringToSign))
return base64.StdEncoding.EncodeToString(hash.Sum(nil))
}
diff --git a/plugins/auth/basic.go b/plugins/auth/basic.go
index 946b8457..c478044a 100644
--- a/plugins/auth/basic.go
+++ b/plugins/auth/basic.go
@@ -46,6 +46,7 @@ import (
var defaultRealm = "Authorization Required"
+// Basic is the http basic auth
func Basic(username string, password string) beego.FilterFunc {
secrets := func(user, pass string) bool {
return user == username && pass == password
@@ -53,6 +54,7 @@ func Basic(username string, password string) beego.FilterFunc {
return NewBasicAuthenticator(secrets, defaultRealm)
}
+// NewBasicAuthenticator return the BasicAuth
func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc {
return func(ctx *context.Context) {
a := &BasicAuth{Secrets: secrets, Realm: Realm}
@@ -62,17 +64,19 @@ func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFun
}
}
+// SecretProvider is the SecretProvider function
type SecretProvider func(user, pass string) bool
+// BasicAuth store the SecretProvider and Realm
type BasicAuth struct {
Secrets SecretProvider
Realm string
}
-//Checks the username/password combination from the request. Returns
-//either an empty string (authentication failed) or the name of the
-//authenticated user.
-//Supports MD5 and SHA1 password entries
+// CheckAuth Checks the username/password combination from the request. Returns
+// either an empty string (authentication failed) or the name of the
+// authenticated user.
+// Supports MD5 and SHA1 password entries
func (a *BasicAuth) CheckAuth(r *http.Request) string {
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
if len(s) != 2 || s[0] != "Basic" {
@@ -94,8 +98,8 @@ func (a *BasicAuth) CheckAuth(r *http.Request) string {
return ""
}
-//http.Handler for BasicAuth which initiates the authentication process
-//(or requires reauthentication).
+// RequireAuth http.Handler for BasicAuth which initiates the authentication process
+// (or requires reauthentication).
func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`)
w.WriteHeader(401)
diff --git a/plugins/cors/cors.go b/plugins/cors/cors.go
index 052d3bc6..1e973a40 100644
--- a/plugins/cors/cors.go
+++ b/plugins/cors/cors.go
@@ -24,7 +24,7 @@
// // - PUT and PATCH methods
// // - Origin header
// // - Credentials share
-// beego.InsertFilter("*", beego.BeforeRouter,cors.Allow(&cors.Options{
+// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{
// AllowOrigins: []string{"https://*.foo.com"},
// AllowMethods: []string{"PUT", "PATCH"},
// AllowHeaders: []string{"Origin"},
@@ -36,7 +36,6 @@
package cors
import (
- "net/http"
"regexp"
"strconv"
"strings"
@@ -216,8 +215,6 @@ func Allow(opts *Options) beego.FilterFunc {
for key, value := range headers {
ctx.Output.Header(key, value)
}
- ctx.Output.SetStatus(http.StatusOK)
- ctx.WriteString("")
return
}
headers = opts.Header(origin)
diff --git a/plugins/cors/cors_test.go b/plugins/cors/cors_test.go
index 5c02ab98..34039143 100644
--- a/plugins/cors/cors_test.go
+++ b/plugins/cors/cors_test.go
@@ -25,21 +25,23 @@ import (
"github.com/astaxie/beego/context"
)
-type HttpHeaderGuardRecorder struct {
+// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header
+type HTTPHeaderGuardRecorder struct {
*httptest.ResponseRecorder
savedHeaderMap http.Header
}
-func NewRecorder() *HttpHeaderGuardRecorder {
- return &HttpHeaderGuardRecorder{httptest.NewRecorder(), nil}
+// NewRecorder return HttpHeaderGuardRecorder
+func NewRecorder() *HTTPHeaderGuardRecorder {
+ return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil}
}
-func (gr *HttpHeaderGuardRecorder) WriteHeader(code int) {
+func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) {
gr.ResponseRecorder.WriteHeader(code)
gr.savedHeaderMap = gr.ResponseRecorder.Header()
}
-func (gr *HttpHeaderGuardRecorder) Header() http.Header {
+func (gr *HTTPHeaderGuardRecorder) Header() http.Header {
if gr.savedHeaderMap != nil {
// headers were written. clone so we don't get updates
clone := make(http.Header)
@@ -47,9 +49,8 @@ func (gr *HttpHeaderGuardRecorder) Header() http.Header {
clone[k] = v
}
return clone
- } else {
- return gr.ResponseRecorder.Header()
}
+ return gr.ResponseRecorder.Header()
}
func Test_AllowAll(t *testing.T) {
@@ -219,13 +220,13 @@ func Test_Preflight(t *testing.T) {
func Benchmark_WithoutCORS(b *testing.B) {
recorder := httptest.NewRecorder()
handler := beego.NewControllerRegister()
- beego.RunMode = "prod"
+ beego.BConfig.RunMode = beego.PROD
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
b.ResetTimer()
- for i := 0; i < 100; i++ {
- r, _ := http.NewRequest("PUT", "/foo", nil)
+ r, _ := http.NewRequest("PUT", "/foo", nil)
+ for i := 0; i < b.N; i++ {
handler.ServeHTTP(recorder, r)
}
}
@@ -233,7 +234,7 @@ func Benchmark_WithoutCORS(b *testing.B) {
func Benchmark_WithCORS(b *testing.B) {
recorder := httptest.NewRecorder()
handler := beego.NewControllerRegister()
- beego.RunMode = "prod"
+ beego.BConfig.RunMode = beego.PROD
handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
AllowAllOrigins: true,
AllowCredentials: true,
@@ -245,8 +246,8 @@ func Benchmark_WithCORS(b *testing.B) {
ctx.Output.SetStatus(500)
})
b.ResetTimer()
- for i := 0; i < 100; i++ {
- r, _ := http.NewRequest("PUT", "/foo", nil)
+ r, _ := http.NewRequest("PUT", "/foo", nil)
+ for i := 0; i < b.N; i++ {
handler.ServeHTTP(recorder, r)
}
}
diff --git a/router.go b/router.go
index b9d649a2..0e1d1d32 100644
--- a/router.go
+++ b/router.go
@@ -15,10 +15,7 @@
package beego
import (
- "bufio"
- "errors"
"fmt"
- "net"
"net/http"
"os"
"path"
@@ -27,6 +24,7 @@ import (
"runtime"
"strconv"
"strings"
+ "sync"
"time"
beecontext "github.com/astaxie/beego/context"
@@ -34,8 +32,8 @@ import (
"github.com/astaxie/beego/utils"
)
+// default filter execution points
const (
- // default filter execution points
BeforeStatic = iota
BeforeRouter
BeforeExec
@@ -50,7 +48,7 @@ const (
)
var (
- // supported http methods.
+ // HTTPMETHOD list the supported http methods.
HTTPMETHOD = map[string]string{
"GET": "GET",
"POST": "POST",
@@ -71,10 +69,12 @@ var (
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
"GetControllerAndAction"}
- url_placeholder = "{{placeholder}}"
- DefaultLogFilter FilterHandler = &logFilter{}
+ urlPlaceholder = "{{placeholder}}"
+ // DefaultAccessLogFilter will skip the accesslog if return true
+ DefaultAccessLogFilter FilterHandler = &logFilter{}
)
+// FilterHandler is an interface for
type FilterHandler interface {
Filter(*beecontext.Context) bool
}
@@ -84,11 +84,11 @@ type logFilter struct {
}
func (l *logFilter) Filter(ctx *beecontext.Context) bool {
- requestPath := path.Clean(ctx.Input.Request.URL.Path)
+ requestPath := path.Clean(ctx.Request.URL.Path)
if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
return true
}
- for prefix, _ := range StaticDir {
+ for prefix := range BConfig.WebConfig.StaticDir {
if strings.HasPrefix(requestPath, prefix) {
return true
}
@@ -96,7 +96,7 @@ func (l *logFilter) Filter(ctx *beecontext.Context) bool {
return false
}
-// To append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter
+// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter
func ExceptMethodAppend(action string) {
exceptMethod = append(exceptMethod, action)
}
@@ -106,26 +106,31 @@ type controllerInfo struct {
controllerType reflect.Type
methods map[string]string
handler http.Handler
- runfunction FilterFunc
+ runFunction FilterFunc
routerType int
}
-// ControllerRegistor containers registered router rules, controller handlers and filters.
-type ControllerRegistor struct {
+// ControllerRegister containers registered router rules, controller handlers and filters.
+type ControllerRegister struct {
routers map[string]*Tree
enableFilter bool
filters map[int][]*FilterRouter
+ pool sync.Pool
}
-// NewControllerRegister returns a new ControllerRegistor.
-func NewControllerRegister() *ControllerRegistor {
- return &ControllerRegistor{
+// NewControllerRegister returns a new ControllerRegister.
+func NewControllerRegister() *ControllerRegister {
+ cr := &ControllerRegister{
routers: make(map[string]*Tree),
filters: make(map[int][]*FilterRouter),
}
+ cr.pool.New = func() interface{} {
+ return beecontext.NewContext()
+ }
+ return cr
}
-// Add controller handler and pattern rules to ControllerRegistor.
+// Add controller handler and pattern rules to ControllerRegister.
// usage:
// default methods is the same name as method
// Add("/user",&UserController{})
@@ -133,9 +138,9 @@ func NewControllerRegister() *ControllerRegistor {
// Add("/api/create",&RestController{},"post:CreateFood")
// Add("/api/update",&RestController{},"put:UpdateFood")
// Add("/api/delete",&RestController{},"delete:DeleteFood")
-// Add("/api",&RestController{},"get,post:ApiFunc")
+// Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
-func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
+func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
methods := make(map[string]string)
@@ -171,7 +176,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
p.addToRouter(m, pattern, route)
}
} else {
- for k, _ := range methods {
+ for k := range methods {
if k == "*" {
for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
@@ -183,8 +188,8 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
}
}
-func (p *ControllerRegistor) addToRouter(method, pattern string, r *controllerInfo) {
- if !RouterCaseSensitive {
+func (p *ControllerRegister) addToRouter(method, pattern string, r *controllerInfo) {
+ if !BConfig.RouterCaseSensitive {
pattern = strings.ToLower(pattern)
}
if t, ok := p.routers[method]; ok {
@@ -196,10 +201,10 @@ func (p *ControllerRegistor) addToRouter(method, pattern string, r *controllerIn
}
}
-// only when the Runmode is dev will generate router file in the router/auto.go from the controller
+// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller
// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
-func (p *ControllerRegistor) Include(cList ...ControllerInterface) {
- if RunMode == "dev" {
+func (p *ControllerRegister) Include(cList ...ControllerInterface) {
+ if BConfig.RunMode == DEV {
skip := make(map[string]bool, 10)
for _, c := range cList {
reflectVal := reflect.ValueOf(c)
@@ -238,101 +243,102 @@ func (p *ControllerRegistor) Include(cList ...ControllerInterface) {
}
}
-// add get method
+// Get add get method
// usage:
// Get("/", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Get(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Get(pattern string, f FilterFunc) {
p.AddMethod("get", pattern, f)
}
-// add post method
+// Post add post method
// usage:
// Post("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Post(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Post(pattern string, f FilterFunc) {
p.AddMethod("post", pattern, f)
}
-// add put method
+// Put add put method
// usage:
// Put("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Put(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Put(pattern string, f FilterFunc) {
p.AddMethod("put", pattern, f)
}
-// add delete method
+// Delete add delete method
// usage:
// Delete("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Delete(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Delete(pattern string, f FilterFunc) {
p.AddMethod("delete", pattern, f)
}
-// add head method
+// Head add head method
// usage:
// Head("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Head(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Head(pattern string, f FilterFunc) {
p.AddMethod("head", pattern, f)
}
-// add patch method
+// Patch add patch method
// usage:
// Patch("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Patch(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Patch(pattern string, f FilterFunc) {
p.AddMethod("patch", pattern, f)
}
-// add options method
+// Options add options method
// usage:
// Options("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Options(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Options(pattern string, f FilterFunc) {
p.AddMethod("options", pattern, f)
}
-// add all method
+// Any add all method
// usage:
// Any("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) Any(pattern string, f FilterFunc) {
+func (p *ControllerRegister) Any(pattern string, f FilterFunc) {
p.AddMethod("*", pattern, f)
}
-// add http method router
+// AddMethod add http method router
// usage:
// AddMethod("get","/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
-func (p *ControllerRegistor) AddMethod(method, pattern string, f FilterFunc) {
- if _, ok := HTTPMETHOD[strings.ToUpper(method)]; method != "*" && !ok {
+func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
+ method = strings.ToUpper(method)
+ if _, ok := HTTPMETHOD[method]; method != "*" && !ok {
panic("not support http method: " + method)
}
route := &controllerInfo{}
route.pattern = pattern
route.routerType = routerTypeRESTFul
- route.runfunction = f
+ route.runFunction = f
methods := make(map[string]string)
if method == "*" {
for _, val := range HTTPMETHOD {
methods[val] = val
}
} else {
- methods[strings.ToUpper(method)] = strings.ToUpper(method)
+ methods[method] = method
}
route.methods = methods
- for k, _ := range methods {
+ for k := range methods {
if k == "*" {
for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
@@ -343,15 +349,15 @@ func (p *ControllerRegistor) AddMethod(method, pattern string, f FilterFunc) {
}
}
-// add user defined Handler
-func (p *ControllerRegistor) Handler(pattern string, h http.Handler, options ...interface{}) {
+// Handler add user defined Handler
+func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
route := &controllerInfo{}
route.pattern = pattern
route.routerType = routerTypeHandler
route.handler = h
if len(options) > 0 {
if _, ok := options[0].(bool); ok {
- pattern = path.Join(pattern, "?:all")
+ pattern = path.Join(pattern, "?:all(.*)")
}
}
for _, m := range HTTPMETHOD {
@@ -359,21 +365,21 @@ func (p *ControllerRegistor) Handler(pattern string, h http.Handler, options ...
}
}
-// Add auto router to ControllerRegistor.
+// AddAuto router to ControllerRegister.
// example beego.AddAuto(&MainContorlller{}),
// MainController has method List and Page.
// visit the url /main/list to execute List function
// /main/page to execute Page function.
-func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
+func (p *ControllerRegister) AddAuto(c ControllerInterface) {
p.AddAutoPrefix("/", c)
}
-// Add auto router to ControllerRegistor with prefix.
+// AddAutoPrefix Add auto router to ControllerRegister with prefix.
// example beego.AddAutoPrefix("/admin",&MainContorlller{}),
// MainController has method List and Page.
// visit the url /admin/main/list to execute List function
// /admin/main/page to execute Page function.
-func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) {
+func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) {
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
@@ -386,28 +392,28 @@ func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface)
route.controllerType = ct
pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*")
patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*")
- patternfix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
- patternfixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
+ patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
+ patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
route.pattern = pattern
for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
p.addToRouter(m, patternInit, route)
- p.addToRouter(m, patternfix, route)
- p.addToRouter(m, patternfixInit, route)
+ p.addToRouter(m, patternFix, route)
+ p.addToRouter(m, patternFixInit, route)
}
}
}
}
-// Add a FilterFunc with pattern rule and action constant.
+// InsertFilter Add a FilterFunc with pattern rule and action constant.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
-func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
+func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
mr := new(FilterRouter)
mr.tree = NewTree()
mr.pattern = pattern
mr.filterFunc = filter
- if !RouterCaseSensitive {
+ if !BConfig.RouterCaseSensitive {
pattern = strings.ToLower(pattern)
}
if len(params) == 0 {
@@ -420,15 +426,15 @@ func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter Filter
}
// add Filter into
-func (p *ControllerRegistor) insertFilterRouter(pos int, mr *FilterRouter) error {
+func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) error {
p.filters[pos] = append(p.filters[pos], mr)
p.enableFilter = true
return nil
}
-// UrlFor does another controller handler in this request function.
+// URLFor does another controller handler in this request function.
// it can access any controller method.
-func (p *ControllerRegistor) UrlFor(endpoint string, values ...interface{}) string {
+func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string {
paths := strings.Split(endpoint, ".")
if len(paths) <= 1 {
Warn("urlfor endpoint must like path.controller.method")
@@ -460,16 +466,16 @@ func (p *ControllerRegistor) UrlFor(endpoint string, values ...interface{}) stri
return ""
}
-func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) {
- for k, subtree := range t.fixrouters {
- u := path.Join(url, k)
+func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) {
+ for _, subtree := range t.fixrouters {
+ u := path.Join(url, subtree.prefix)
ok, u := p.geturl(subtree, u, controllName, methodName, params, httpMethod)
if ok {
return ok, u
}
}
if t.wildcard != nil {
- u := path.Join(url, url_placeholder)
+ u := path.Join(url, urlPlaceholder)
ok, u := p.geturl(t.wildcard, u, controllName, methodName, params, httpMethod)
if ok {
return ok, u
@@ -499,22 +505,21 @@ func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName strin
if find {
if l.regexps == nil {
if len(l.wildcards) == 0 {
- return true, strings.Replace(url, "/"+url_placeholder, "", 1) + tourl(params)
+ return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toURL(params)
}
if len(l.wildcards) == 1 {
if v, ok := params[l.wildcards[0]]; ok {
delete(params, l.wildcards[0])
- return true, strings.Replace(url, url_placeholder, v, 1) + tourl(params)
- } else {
- return false, ""
+ return true, strings.Replace(url, urlPlaceholder, v, 1) + toURL(params)
}
+ return false, ""
}
if len(l.wildcards) == 3 && l.wildcards[0] == "." {
if p, ok := params[":path"]; ok {
if e, isok := params[":ext"]; isok {
delete(params, ":path")
delete(params, ":ext")
- return true, strings.Replace(url, url_placeholder, p+"."+e, -1) + tourl(params)
+ return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toURL(params)
}
}
}
@@ -526,45 +531,43 @@ func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName strin
}
if u, ok := params[v]; ok {
delete(params, v)
- url = strings.Replace(url, url_placeholder, u, 1)
+ url = strings.Replace(url, urlPlaceholder, u, 1)
} else {
if canskip {
canskip = false
continue
- } else {
- return false, ""
}
+ return false, ""
}
}
- return true, url + tourl(params)
- } else {
- var i int
- var startreg bool
- regurl := ""
- for _, v := range strings.Trim(l.regexps.String(), "^$") {
- if v == '(' {
- startreg = true
- continue
- } else if v == ')' {
- startreg = false
- if v, ok := params[l.wildcards[i]]; ok {
- delete(params, l.wildcards[i])
- regurl = regurl + v
- i++
- } else {
- break
- }
- } else if !startreg {
- regurl = string(append([]rune(regurl), v))
+ return true, url + toURL(params)
+ }
+ var i int
+ var startreg bool
+ regurl := ""
+ for _, v := range strings.Trim(l.regexps.String(), "^$") {
+ if v == '(' {
+ startreg = true
+ continue
+ } else if v == ')' {
+ startreg = false
+ if v, ok := params[l.wildcards[i]]; ok {
+ delete(params, l.wildcards[i])
+ regurl = regurl + v
+ i++
+ } else {
+ break
}
+ } else if !startreg {
+ regurl = string(append([]rune(regurl), v))
}
- if l.regexps.MatchString(regurl) {
- ps := strings.Split(regurl, "/")
- for _, p := range ps {
- url = strings.Replace(url, url_placeholder, p, 1)
- }
- return true, url + tourl(params)
+ }
+ if l.regexps.MatchString(regurl) {
+ ps := strings.Split(regurl, "/")
+ for _, p := range ps {
+ url = strings.Replace(url, urlPlaceholder, p, 1)
}
+ return true, url + toURL(params)
}
}
}
@@ -574,159 +577,139 @@ func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName strin
return false, ""
}
+func (p *ControllerRegister) execFilter(context *beecontext.Context, pos int, urlPath string) (started bool) {
+ if p.enableFilter {
+ if l, ok := p.filters[pos]; ok {
+ for _, filterR := range l {
+ if filterR.returnOnOutput && context.ResponseWriter.Started {
+ return true
+ }
+ if ok := filterR.ValidRouter(urlPath, context); ok {
+ filterR.filterFunc(context)
+ }
+ if filterR.returnOnOutput && context.ResponseWriter.Started {
+ return true
+ }
+ }
+ }
+ }
+ return false
+}
+
// Implement http.Handler interface.
-func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
- starttime := time.Now()
- var runrouter reflect.Type
- var findrouter bool
- var runMethod string
- var routerInfo *controllerInfo
-
- w := &responseWriter{writer: rw}
-
- if RunMode == "dev" {
- w.Header().Set("Server", BeegoServerName)
- }
-
- // init context
- context := &beecontext.Context{
- ResponseWriter: w,
- Request: r,
- Input: beecontext.NewInput(r),
- Output: beecontext.NewOutput(),
- }
- context.Output.Context = context
- context.Output.EnableGzip = EnableGzip
-
+func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
+ startTime := time.Now()
+ var (
+ runRouter reflect.Type
+ findRouter bool
+ runMethod string
+ routerInfo *controllerInfo
+ )
+ context := p.pool.Get().(*beecontext.Context)
+ context.Reset(rw, r)
+ defer p.pool.Put(context)
defer p.recoverPanic(context)
+ context.Output.EnableGzip = BConfig.EnableGzip
+
+ if BConfig.RunMode == DEV {
+ context.Output.Header("Server", BConfig.ServerName)
+ }
+
var urlPath string
- if !RouterCaseSensitive {
+ if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(r.URL.Path)
} else {
urlPath = r.URL.Path
}
- // defined filter function
- do_filter := func(pos int) (started bool) {
- if p.enableFilter {
- if l, ok := p.filters[pos]; ok {
- for _, filterR := range l {
- if ok, p := filterR.ValidRouter(urlPath); ok {
- context.Input.Params = p
- filterR.filterFunc(context)
- if filterR.returnOnOutput && w.started {
- return true
- }
- }
- }
- }
- }
- return false
- }
-
- // filter wrong httpmethod
+ // filter wrong http method
if _, ok := HTTPMETHOD[r.Method]; !ok {
- http.Error(w, "Method Not Allowed", 405)
+ http.Error(rw, "Method Not Allowed", 405)
goto Admin
}
// filter for static file
- if do_filter(BeforeStatic) {
+ if p.execFilter(context, BeforeStatic, urlPath) {
goto Admin
}
serverStaticRouter(context)
- if w.started {
- findrouter = true
+ if context.ResponseWriter.Started {
+ findRouter = true
goto Admin
}
// session init
- if SessionOn {
+ if BConfig.WebConfig.Session.SessionOn {
var err error
- context.Input.CruSession, err = GlobalSessions.SessionStart(w, r)
+ context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
if err != nil {
Error(err)
exception("503", context)
return
}
defer func() {
- context.Input.CruSession.SessionRelease(w)
+ if context.Input.CruSession != nil {
+ context.Input.CruSession.SessionRelease(rw)
+ }
}()
}
if r.Method != "GET" && r.Method != "HEAD" {
- if CopyRequestBody && !context.Input.IsUpload() {
- context.Input.CopyBody()
+ if BConfig.CopyRequestBody && !context.Input.IsUpload() {
+ context.Input.CopyBody(BConfig.MaxMemory)
}
- context.Input.ParseFormOrMulitForm(MaxMemory)
+ context.Input.ParseFormOrMulitForm(BConfig.MaxMemory)
}
- if do_filter(BeforeRouter) {
+ if p.execFilter(context, BeforeRouter, urlPath) {
goto Admin
}
- if context.Input.RunController != nil && context.Input.RunMethod != "" {
- findrouter = true
- runMethod = context.Input.RunMethod
- runrouter = context.Input.RunController
- }
-
- if !findrouter {
- http_method := r.Method
-
- if http_method == "POST" && context.Input.Query("_method") == "PUT" {
- http_method = "PUT"
- }
-
- if http_method == "POST" && context.Input.Query("_method") == "DELETE" {
- http_method = "DELETE"
- }
-
- if t, ok := p.routers[http_method]; ok {
- runObject, p := t.Match(urlPath)
+ if !findRouter {
+ httpMethod := r.Method
+ if t, ok := p.routers[httpMethod]; ok {
+ runObject := t.Match(urlPath, context)
if r, ok := runObject.(*controllerInfo); ok {
routerInfo = r
- findrouter = true
- if splat, ok := p[":splat"]; ok {
- splatlist := strings.Split(splat, "/")
- for k, v := range splatlist {
- p[strconv.Itoa(k)] = v
+ findRouter = true
+ if splat := context.Input.Param(":splat"); splat != "" {
+ for k, v := range strings.Split(splat, "/") {
+ context.Input.SetParam(strconv.Itoa(k), v)
}
}
- context.Input.Params = p
}
}
}
//if no matches to url, throw a not found exception
- if !findrouter {
+ if !findRouter {
exception("404", context)
goto Admin
}
- if findrouter {
+ if findRouter {
//execute middleware filters
- if do_filter(BeforeExec) {
+ if p.execFilter(context, BeforeExec, urlPath) {
goto Admin
}
- isRunable := false
+ isRunnable := false
if routerInfo != nil {
if routerInfo.routerType == routerTypeRESTFul {
if _, ok := routerInfo.methods[r.Method]; ok {
- isRunable = true
- routerInfo.runfunction(context)
+ isRunnable = true
+ routerInfo.runFunction(context)
} else {
exception("405", context)
goto Admin
}
} else if routerInfo.routerType == routerTypeHandler {
- isRunable = true
+ isRunnable = true
routerInfo.handler.ServeHTTP(rw, r)
} else {
- runrouter = routerInfo.controllerType
+ runRouter = routerInfo.controllerType
method := r.Method
if r.Method == "POST" && context.Input.Query("_method") == "PUT" {
method = "PUT"
@@ -744,33 +727,33 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
}
- // also defined runrouter & runMethod from filter
- if !isRunable {
+ // also defined runRouter & runMethod from filter
+ if !isRunnable {
//Invoke the request handler
- vc := reflect.New(runrouter)
+ vc := reflect.New(runRouter)
execController, ok := vc.Interface().(ControllerInterface)
if !ok {
panic("controller is not ControllerInterface")
}
//call the controller init function
- execController.Init(context, runrouter.Name(), runMethod, vc.Interface())
+ execController.Init(context, runRouter.Name(), runMethod, vc.Interface())
//call prepare function
execController.Prepare()
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
- if EnableXSRF {
- execController.XsrfToken()
+ if BConfig.WebConfig.EnableXSRF {
+ execController.XSRFToken()
if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" ||
(r.Method == "POST" && (context.Input.Query("_method") == "DELETE" || context.Input.Query("_method") == "PUT")) {
- execController.CheckXsrfCookie()
+ execController.CheckXSRFCookie()
}
}
execController.URLMapping()
- if !w.started {
+ if !context.ResponseWriter.Started {
//exec main logic
switch runMethod {
case "GET":
@@ -789,15 +772,15 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.Options()
default:
if !execController.HandlerFunc(runMethod) {
- in := make([]reflect.Value, 0)
+ var in []reflect.Value
method := vc.MethodByName(runMethod)
method.Call(in)
}
}
//render template
- if !w.started && context.Output.Status == 0 {
- if AutoRender {
+ if !context.ResponseWriter.Started && context.Output.Status == 0 {
+ if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil {
panic(err)
}
@@ -805,151 +788,86 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
}
- // finish all runrouter. release resource
+ // finish all runRouter. release resource
execController.Finish()
}
//execute middleware filters
- if do_filter(AfterExec) {
+ if p.execFilter(context, AfterExec, urlPath) {
goto Admin
}
}
- do_filter(FinishRouter)
+ p.execFilter(context, FinishRouter, urlPath)
Admin:
- timeend := time.Since(starttime)
+ timeDur := time.Since(startTime)
//admin module record QPS
- if EnableAdmin {
- if FilterMonitorFunc(r.Method, r.URL.Path, timeend) {
- if runrouter != nil {
- go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runrouter.Name(), timeend)
+ if BConfig.Listen.EnableAdmin {
+ if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) {
+ if runRouter != nil {
+ go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
} else {
- go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, "", timeend)
+ go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, "", timeDur)
}
}
}
- if RunMode == "dev" || AccessLogs {
- var devinfo string
- if findrouter {
+ if BConfig.RunMode == DEV || BConfig.Log.AccessLogs {
+ var devInfo string
+ if findRouter {
if routerInfo != nil {
- devinfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s | % -40s |", r.Method, r.URL.Path, timeend.String(), "match", routerInfo.pattern)
+ devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s | % -40s |", r.Method, r.URL.Path, timeDur.String(), "match", routerInfo.pattern)
} else {
- devinfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeend.String(), "match")
+ devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "match")
}
} else {
- devinfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeend.String(), "notmatch")
+ devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch")
}
- if DefaultLogFilter == nil || !DefaultLogFilter.Filter(context) {
- Debug(devinfo)
+ if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) {
+ Debug(devInfo)
}
}
// Call WriteHeader if status code has been set changed
if context.Output.Status != 0 {
- w.writer.WriteHeader(context.Output.Status)
+ context.ResponseWriter.WriteHeader(context.Output.Status)
}
}
-func (p *ControllerRegistor) recoverPanic(context *beecontext.Context) {
+func (p *ControllerRegister) recoverPanic(context *beecontext.Context) {
if err := recover(); err != nil {
- if err == USERSTOPRUN {
+ if err == ErrAbort {
return
}
- if RunMode == "dev" {
- if !RecoverPanic {
- panic(err)
- } else {
- if ErrorsShow {
- if handler, ok := ErrorMaps[fmt.Sprint(err)]; ok {
- executeError(handler, context)
- return
- }
- }
- var stack string
- Critical("the request url is ", context.Input.Url())
- Critical("Handler crashed with error", err)
- for i := 1; ; i++ {
- _, file, line, ok := runtime.Caller(i)
- if !ok {
- break
- }
- Critical(fmt.Sprintf("%s:%d", file, line))
- stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
- }
- showErr(err, context, stack)
- }
+ if !BConfig.RecoverPanic {
+ panic(err)
} else {
- if !RecoverPanic {
- panic(err)
- } else {
- // in production model show all infomation
- if ErrorsShow {
- if handler, ok := ErrorMaps[fmt.Sprint(err)]; ok {
- executeError(handler, context)
- return
- } else if handler, ok := ErrorMaps["503"]; ok {
- executeError(handler, context)
- return
- } else {
- context.WriteString(fmt.Sprint(err))
- }
- } else {
- Critical("the request url is ", context.Input.Url())
- Critical("Handler crashed with error", err)
- for i := 1; ; i++ {
- _, file, line, ok := runtime.Caller(i)
- if !ok {
- break
- }
- Critical(fmt.Sprintf("%s:%d", file, line))
- }
+ if BConfig.EnableErrorsShow {
+ if _, ok := ErrorMaps[fmt.Sprint(err)]; ok {
+ exception(fmt.Sprint(err), context)
+ return
}
}
+ var stack string
+ Critical("the request url is ", context.Input.URL())
+ Critical("Handler crashed with error", err)
+ for i := 1; ; i++ {
+ _, file, line, ok := runtime.Caller(i)
+ if !ok {
+ break
+ }
+ Critical(fmt.Sprintf("%s:%d", file, line))
+ stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
+ }
+ if BConfig.RunMode == DEV {
+ showErr(err, context, stack)
+ }
}
}
}
-//responseWriter is a wrapper for the http.ResponseWriter
-//started set to true if response was written to then don't execute other handler
-type responseWriter struct {
- writer http.ResponseWriter
- started bool
- status int
-}
-
-// Header returns the header map that will be sent by WriteHeader.
-func (w *responseWriter) Header() http.Header {
- return w.writer.Header()
-}
-
-// Write writes the data to the connection as part of an HTTP reply,
-// and sets `started` to true.
-// started means the response has sent out.
-func (w *responseWriter) Write(p []byte) (int, error) {
- w.started = true
- return w.writer.Write(p)
-}
-
-// WriteHeader sends an HTTP response header with status code,
-// and sets `started` to true.
-func (w *responseWriter) WriteHeader(code int) {
- w.status = code
- w.started = true
- w.writer.WriteHeader(code)
-}
-
-// hijacker for http
-func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
- hj, ok := w.writer.(http.Hijacker)
- if !ok {
- return nil, nil, errors.New("webserver doesn't support hijacking")
- }
- return hj.Hijack()
-}
-
-func tourl(params map[string]string) string {
+func toURL(params map[string]string) string {
if len(params) == 0 {
return ""
}
diff --git a/router_test.go b/router_test.go
index ee712167..b0ae7a18 100644
--- a/router_test.go
+++ b/router_test.go
@@ -45,24 +45,24 @@ func (tc *TestController) List() {
}
func (tc *TestController) Params() {
- tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Params["0"] + tc.Ctx.Input.Params["1"] + tc.Ctx.Input.Params["2"]))
+ tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param("0") + tc.Ctx.Input.Param("1") + tc.Ctx.Input.Param("2")))
}
func (tc *TestController) Myext() {
tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param(":ext")))
}
-func (tc *TestController) GetUrl() {
- tc.Ctx.Output.Body([]byte(tc.UrlFor(".Myext")))
+func (tc *TestController) GetURL() {
+ tc.Ctx.Output.Body([]byte(tc.URLFor(".Myext")))
}
-func (t *TestController) GetParams() {
- t.Ctx.WriteString(t.Ctx.Input.Query(":last") + "+" +
- t.Ctx.Input.Query(":first") + "+" + t.Ctx.Input.Query("learn"))
+func (tc *TestController) GetParams() {
+ tc.Ctx.WriteString(tc.Ctx.Input.Query(":last") + "+" +
+ tc.Ctx.Input.Query(":first") + "+" + tc.Ctx.Input.Query("learn"))
}
-func (t *TestController) GetManyRouter() {
- t.Ctx.WriteString(t.Ctx.Input.Query(":id") + t.Ctx.Input.Query(":page"))
+func (tc *TestController) GetManyRouter() {
+ tc.Ctx.WriteString(tc.Ctx.Input.Query(":id") + tc.Ctx.Input.Query(":page"))
}
type ResStatus struct {
@@ -70,29 +70,29 @@ type ResStatus struct {
Msg string
}
-type JsonController struct {
+type JSONController struct {
Controller
}
-func (this *JsonController) Prepare() {
- this.Data["json"] = "prepare"
- this.ServeJson(true)
+func (jc *JSONController) Prepare() {
+ jc.Data["json"] = "prepare"
+ jc.ServeJSON(true)
}
-func (this *JsonController) Get() {
- this.Data["Username"] = "astaxie"
- this.Ctx.Output.Body([]byte("ok"))
+func (jc *JSONController) Get() {
+ jc.Data["Username"] = "astaxie"
+ jc.Ctx.Output.Body([]byte("ok"))
}
func TestUrlFor(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, "*:List")
handler.Add("/person/:last/:first", &TestController{}, "*:Param")
- if a := handler.UrlFor("TestController.List"); a != "/api/list" {
+ if a := handler.URLFor("TestController.List"); a != "/api/list" {
Info(a)
t.Errorf("TestController.List must equal to /api/list")
}
- if a := handler.UrlFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" {
+ if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" {
t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + a)
}
}
@@ -100,39 +100,39 @@ func TestUrlFor(t *testing.T) {
func TestUrlFor3(t *testing.T) {
handler := NewControllerRegister()
handler.AddAuto(&TestController{})
- if a := handler.UrlFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" {
+ if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" {
t.Errorf("TestController.Myext must equal to /test/myext, but get " + a)
}
- if a := handler.UrlFor("TestController.GetUrl"); a != "/test/geturl" && a != "/Test/GetUrl" {
- t.Errorf("TestController.GetUrl must equal to /test/geturl, but get " + a)
+ if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" {
+ t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a)
}
}
func TestUrlFor2(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List")
- handler.Add("/v1/:username/edit", &TestController{}, "get:GetUrl")
+ handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL")
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param")
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
- if handler.UrlFor("TestController.GetUrl", ":username", "astaxie") != "/v1/astaxie/edit" {
- Info(handler.UrlFor("TestController.GetUrl"))
+ if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
+ Info(handler.URLFor("TestController.GetURL"))
t.Errorf("TestController.List must equal to /v1/astaxie/edit")
}
- if handler.UrlFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") !=
+ if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za/cms_12_123.html" {
- Info(handler.UrlFor("TestController.List"))
+ Info(handler.URLFor("TestController.List"))
t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html")
}
- if handler.UrlFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") !=
+ if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za_cms/ttt_12_123.html" {
- Info(handler.UrlFor("TestController.Param"))
+ Info(handler.URLFor("TestController.Param"))
t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html")
}
- if handler.UrlFor("TestController.Get", ":year", "1111", ":month", "11",
+ if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11",
":title", "aaaa", ":entid", "aaaa") !=
"/1111/11/aaaa/aaaa" {
- Info(handler.UrlFor("TestController.Get"))
+ Info(handler.URLFor("TestController.Get"))
t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa")
}
}
@@ -270,7 +270,7 @@ func TestPrepare(t *testing.T) {
w := httptest.NewRecorder()
handler := NewControllerRegister()
- handler.Add("/json/list", &JsonController{})
+ handler.Add("/json/list", &JSONController{})
handler.ServeHTTP(w, r)
if w.Body.String() != `"prepare"` {
t.Errorf(w.Body.String() + "user define func can't run")
@@ -333,6 +333,18 @@ func TestRouterHandler(t *testing.T) {
}
}
+func TestRouterHandlerAll(t *testing.T) {
+ r, _ := http.NewRequest("POST", "/sayhi/a/b/c", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Handler("/sayhi", http.HandlerFunc(sayhello), true)
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "sayhello" {
+ t.Errorf("TestRouterHandler can't run")
+ }
+}
+
//
// Benchmarks NewApp:
//
@@ -444,7 +456,7 @@ func TestFilterAfterExec(t *testing.T) {
mux := NewControllerRegister()
mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput)
mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput)
- mux.InsertFilter(url, AfterExec, beegoAfterExec1)
+ mux.InsertFilter(url, AfterExec, beegoAfterExec1, false)
mux.Get(url, beegoFilterFunc)
@@ -506,7 +518,7 @@ func TestFilterFinishRouterMultiFirstOnly(t *testing.T) {
url := "/finishRouterMultiFirstOnly"
mux := NewControllerRegister()
- mux.InsertFilter(url, FinishRouter, beegoFinishRouter1)
+ mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false)
mux.InsertFilter(url, FinishRouter, beegoFinishRouter2)
mux.Get(url, beegoFilterFunc)
@@ -534,7 +546,7 @@ func TestFilterFinishRouterMulti(t *testing.T) {
mux := NewControllerRegister()
mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false)
- mux.InsertFilter(url, FinishRouter, beegoFinishRouter2)
+ mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false)
mux.Get(url, beegoFilterFunc)
diff --git a/session/couchbase/sess_couchbase.go b/session/couchbase/sess_couchbase.go
index 827d55d9..d5be11d0 100644
--- a/session/couchbase/sess_couchbase.go
+++ b/session/couchbase/sess_couchbase.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package couchbase for session provider
+// Package couchbase for session provider
//
// depend on github.com/couchbaselabs/go-couchbasee
//
@@ -30,21 +30,22 @@
// }
//
// more docs: http://beego.me/docs/module/session.md
-package session
+package couchbase
import (
"net/http"
"strings"
"sync"
- "github.com/couchbaselabs/go-couchbase"
+ couchbase "github.com/couchbase/go-couchbase"
"github.com/astaxie/beego/session"
)
-var couchbpder = &CouchbaseProvider{}
+var couchbpder = &Provider{}
-type CouchbaseSessionStore struct {
+// SessionStore store each session
+type SessionStore struct {
b *couchbase.Bucket
sid string
lock sync.RWMutex
@@ -52,7 +53,8 @@ type CouchbaseSessionStore struct {
maxlifetime int64
}
-type CouchbaseProvider struct {
+// Provider couchabse provided
+type Provider struct {
maxlifetime int64
savePath string
pool string
@@ -60,42 +62,47 @@ type CouchbaseProvider struct {
b *couchbase.Bucket
}
-func (cs *CouchbaseSessionStore) Set(key, value interface{}) error {
+// Set value to couchabse session
+func (cs *SessionStore) Set(key, value interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values[key] = value
return nil
}
-func (cs *CouchbaseSessionStore) Get(key interface{}) interface{} {
+// Get value from couchabse session
+func (cs *SessionStore) Get(key interface{}) interface{} {
cs.lock.RLock()
defer cs.lock.RUnlock()
if v, ok := cs.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-func (cs *CouchbaseSessionStore) Delete(key interface{}) error {
+// Delete value in couchbase session by given key
+func (cs *SessionStore) Delete(key interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
delete(cs.values, key)
return nil
}
-func (cs *CouchbaseSessionStore) Flush() error {
+// Flush Clean all values in couchbase session
+func (cs *SessionStore) Flush() error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values = make(map[interface{}]interface{})
return nil
}
-func (cs *CouchbaseSessionStore) SessionID() string {
+// SessionID Get couchbase session store id
+func (cs *SessionStore) SessionID() string {
return cs.sid
}
-func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) {
+// SessionRelease Write couchbase session with Gob string
+func (cs *SessionStore) SessionRelease(w http.ResponseWriter) {
defer cs.b.Close()
bo, err := session.EncodeGob(cs.values)
@@ -106,7 +113,7 @@ func (cs *CouchbaseSessionStore) SessionRelease(w http.ResponseWriter) {
cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
}
-func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket {
+func (cp *Provider) getBucket() *couchbase.Bucket {
c, err := couchbase.Connect(cp.savePath)
if err != nil {
return nil
@@ -125,10 +132,10 @@ func (cp *CouchbaseProvider) getBucket() *couchbase.Bucket {
return bucket
}
-// init couchbase session
+// SessionInit init couchbase session
// savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket
-func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) error {
+func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error {
cp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
@@ -144,8 +151,8 @@ func (cp *CouchbaseProvider) SessionInit(maxlifetime int64, savePath string) err
return nil
}
-// read couchbase session by sid
-func (cp *CouchbaseProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead read couchbase session by sid
+func (cp *Provider) SessionRead(sid string) (session.Store, error) {
cp.b = cp.getBucket()
var doc []byte
@@ -161,11 +168,13 @@ func (cp *CouchbaseProvider) SessionRead(sid string) (session.SessionStore, erro
}
}
- cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
+ cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
-func (cp *CouchbaseProvider) SessionExist(sid string) bool {
+// SessionExist Check couchbase session exist.
+// it checkes sid exist or not.
+func (cp *Provider) SessionExist(sid string) bool {
cp.b = cp.getBucket()
defer cp.b.Close()
@@ -173,12 +182,12 @@ func (cp *CouchbaseProvider) SessionExist(sid string) bool {
if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
return false
- } else {
- return true
}
+ return true
}
-func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate remove oldsid and use sid to generate new session
+func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
cp.b = cp.getBucket()
var doc []byte
@@ -206,11 +215,12 @@ func (cp *CouchbaseProvider) SessionRegenerate(oldsid, sid string) (session.Sess
}
}
- cs := &CouchbaseSessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
+ cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
-func (cp *CouchbaseProvider) SessionDestroy(sid string) error {
+// SessionDestroy Remove bucket in this couchbase
+func (cp *Provider) SessionDestroy(sid string) error {
cp.b = cp.getBucket()
defer cp.b.Close()
@@ -218,11 +228,13 @@ func (cp *CouchbaseProvider) SessionDestroy(sid string) error {
return nil
}
-func (cp *CouchbaseProvider) SessionGC() {
+// SessionGC Recycle
+func (cp *Provider) SessionGC() {
return
}
-func (cp *CouchbaseProvider) SessionAll() int {
+// SessionAll return all active session
+func (cp *Provider) SessionAll() int {
return 0
}
diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go
index 643b8817..68f37b08 100644
--- a/session/ledis/ledis_session.go
+++ b/session/ledis/ledis_session.go
@@ -1,4 +1,5 @@
-package session
+// Package ledis provide session Provider
+package ledis
import (
"net/http"
@@ -11,59 +12,58 @@ import (
"github.com/siddontang/ledisdb/ledis"
)
-var ledispder = &LedisProvider{}
+var ledispder = &Provider{}
var c *ledis.DB
-// ledis session store
-type LedisSessionStore struct {
+// SessionStore ledis session store
+type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
-// set value in ledis session
-func (ls *LedisSessionStore) Set(key, value interface{}) error {
+// Set value in ledis session
+func (ls *SessionStore) Set(key, value interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values[key] = value
return nil
}
-// get value in ledis session
-func (ls *LedisSessionStore) Get(key interface{}) interface{} {
+// Get value in ledis session
+func (ls *SessionStore) Get(key interface{}) interface{} {
ls.lock.RLock()
defer ls.lock.RUnlock()
if v, ok := ls.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete value in ledis session
-func (ls *LedisSessionStore) Delete(key interface{}) error {
+// Delete value in ledis session
+func (ls *SessionStore) Delete(key interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
delete(ls.values, key)
return nil
}
-// clear all values in ledis session
-func (ls *LedisSessionStore) Flush() error {
+// Flush clear all values in ledis session
+func (ls *SessionStore) Flush() error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values = make(map[interface{}]interface{})
return nil
}
-// get ledis session id
-func (ls *LedisSessionStore) SessionID() string {
+// SessionID get ledis session id
+func (ls *SessionStore) SessionID() string {
return ls.sid
}
-// save session values to ledis
-func (ls *LedisSessionStore) SessionRelease(w http.ResponseWriter) {
+// SessionRelease save session values to ledis
+func (ls *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(ls.values)
if err != nil {
return
@@ -72,17 +72,17 @@ func (ls *LedisSessionStore) SessionRelease(w http.ResponseWriter) {
c.Expire([]byte(ls.sid), ls.maxlifetime)
}
-// ledis session provider
-type LedisProvider struct {
+// Provider ledis session provider
+type Provider struct {
maxlifetime int64
savePath string
db int
}
-// init ledis session
+// SessionInit init ledis session
// savepath like ledis server saveDataPath,pool size
// e.g. 127.0.0.1:6379,100,astaxie
-func (lp *LedisProvider) SessionInit(maxlifetime int64, savePath string) error {
+func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error {
var err error
lp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
@@ -106,8 +106,8 @@ func (lp *LedisProvider) SessionInit(maxlifetime int64, savePath string) error {
return nil
}
-// read ledis session by sid
-func (lp *LedisProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead read ledis session by sid
+func (lp *Provider) SessionRead(sid string) (session.Store, error) {
kvs, err := c.Get([]byte(sid))
var kv map[interface{}]interface{}
if len(kvs) == 0 {
@@ -118,22 +118,21 @@ func (lp *LedisProvider) SessionRead(sid string) (session.SessionStore, error) {
return nil, err
}
}
- ls := &LedisSessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
+ ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
return ls, nil
}
-// check ledis session exist by sid
-func (lp *LedisProvider) SessionExist(sid string) bool {
+// SessionExist check ledis session exist by sid
+func (lp *Provider) SessionExist(sid string) bool {
count, _ := c.Exists([]byte(sid))
if count == 0 {
return false
- } else {
- return true
}
+ return true
}
-// generate new sid for ledis session
-func (lp *LedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate generate new sid for ledis session
+func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
count, _ := c.Exists([]byte(sid))
if count == 0 {
// oldsid doesn't exists, set the new sid directly
@@ -156,23 +155,23 @@ func (lp *LedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionS
return nil, err
}
}
- ls := &LedisSessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
+ ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
return ls, nil
}
-// delete ledis session by id
-func (lp *LedisProvider) SessionDestroy(sid string) error {
+// SessionDestroy delete ledis session by id
+func (lp *Provider) SessionDestroy(sid string) error {
c.Del([]byte(sid))
return nil
}
-// Impelment method, no used.
-func (lp *LedisProvider) SessionGC() {
+// SessionGC Impelment method, no used.
+func (lp *Provider) SessionGC() {
return
}
-// @todo
-func (lp *LedisProvider) SessionAll() int {
+// SessionAll return all active session
+func (lp *Provider) SessionAll() int {
return 0
}
func init() {
diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go
index 5827a0a9..f1069bc9 100644
--- a/session/memcache/sess_memcache.go
+++ b/session/memcache/sess_memcache.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package memcache for session provider
+// Package memcache for session provider
//
// depend on github.com/bradfitz/gomemcache/memcache
//
@@ -30,7 +30,7 @@
// }
//
// more docs: http://beego.me/docs/module/session.md
-package session
+package memcache
import (
"net/http"
@@ -45,56 +45,55 @@ import (
var mempder = &MemProvider{}
var client *memcache.Client
-// memcache session store
-type MemcacheSessionStore struct {
+// SessionStore memcache session store
+type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
-// set value in memcache session
-func (rs *MemcacheSessionStore) Set(key, value interface{}) error {
+// Set value in memcache session
+func (rs *SessionStore) Set(key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
-// get value in memcache session
-func (rs *MemcacheSessionStore) Get(key interface{}) interface{} {
+// Get value in memcache session
+func (rs *SessionStore) Get(key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete value in memcache session
-func (rs *MemcacheSessionStore) Delete(key interface{}) error {
+// Delete value in memcache session
+func (rs *SessionStore) Delete(key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
-// clear all values in memcache session
-func (rs *MemcacheSessionStore) Flush() error {
+// Flush clear all values in memcache session
+func (rs *SessionStore) Flush() error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
-// get memcache session id
-func (rs *MemcacheSessionStore) SessionID() string {
+// SessionID get memcache session id
+func (rs *SessionStore) SessionID() string {
return rs.sid
}
-// save session values to memcache
-func (rs *MemcacheSessionStore) SessionRelease(w http.ResponseWriter) {
+// SessionRelease save session values to memcache
+func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
@@ -103,7 +102,7 @@ func (rs *MemcacheSessionStore) SessionRelease(w http.ResponseWriter) {
client.Set(&item)
}
-// memcahe session provider
+// MemProvider memcache session provider
type MemProvider struct {
maxlifetime int64
conninfo []string
@@ -111,7 +110,7 @@ type MemProvider struct {
password string
}
-// init memcache session
+// SessionInit init memcache session
// savepath like
// e.g. 127.0.0.1:9090
func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
@@ -121,16 +120,17 @@ func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
return nil
}
-// read memcache session by sid
-func (rp *MemProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead read memcache session by sid
+func (rp *MemProvider) SessionRead(sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
}
}
item, err := client.Get(sid)
- if err != nil {
- return nil, err
+ if err != nil && err == memcache.ErrCacheMiss {
+ rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime}
+ return rs, nil
}
var kv map[interface{}]interface{}
if len(item.Value) == 0 {
@@ -141,12 +141,11 @@ func (rp *MemProvider) SessionRead(sid string) (session.SessionStore, error) {
return nil, err
}
}
-
- rs := &MemcacheSessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
-// check memcache session exist by sid
+// SessionExist check memcache session exist by sid
func (rp *MemProvider) SessionExist(sid string) bool {
if client == nil {
if err := rp.connectInit(); err != nil {
@@ -155,13 +154,12 @@ func (rp *MemProvider) SessionExist(sid string) bool {
}
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
return false
- } else {
- return true
}
+ return true
}
-// generate new sid for memcache session
-func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate generate new sid for memcache session
+func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
@@ -179,7 +177,6 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.SessionSto
} else {
client.Delete(oldsid)
item.Key = sid
- item.Value = item.Value
item.Expiration = int32(rp.maxlifetime)
client.Set(item)
contain = item.Value
@@ -196,11 +193,11 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.SessionSto
}
}
- rs := &MemcacheSessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
-// delete memcache session by id
+// SessionDestroy delete memcache session by id
func (rp *MemProvider) SessionDestroy(sid string) error {
if client == nil {
if err := rp.connectInit(); err != nil {
@@ -220,12 +217,12 @@ func (rp *MemProvider) connectInit() error {
return nil
}
-// Impelment method, no used.
+// SessionGC Impelment method, no used.
func (rp *MemProvider) SessionGC() {
return
}
-// @todo
+// SessionAll return all activeSession
func (rp *MemProvider) SessionAll() int {
return 0
}
diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go
index 0f6d3e4f..969d26c9 100644
--- a/session/mysql/sess_mysql.go
+++ b/session/mysql/sess_mysql.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package mysql for session provider
+// Package mysql for session provider
//
// depends on github.com/go-sql-driver/mysql:
//
@@ -21,7 +21,7 @@
// mysql session support need create table as sql:
// CREATE TABLE `session` (
// `session_key` char(64) NOT NULL,
-// session_data` blob,
+// `session_data` blob,
// `session_expiry` int(11) unsigned NOT NULL,
// PRIMARY KEY (`session_key`)
// ) ENGINE=MyISAM DEFAULT CHARSET=utf8;
@@ -38,7 +38,7 @@
// }
//
// more docs: http://beego.me/docs/module/session.md
-package session
+package mysql
import (
"database/sql"
@@ -47,82 +47,85 @@ import (
"time"
"github.com/astaxie/beego/session"
-
+ // import mysql driver
_ "github.com/go-sql-driver/mysql"
)
-var mysqlpder = &MysqlProvider{}
+var (
+ // TableName store the session in MySQL
+ TableName = "session"
+ mysqlpder = &Provider{}
+)
-// mysql session store
-type MysqlSessionStore struct {
+// SessionStore mysql session store
+type SessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
-// set value in mysql session.
+// Set value in mysql session.
// it is temp value in map.
-func (st *MysqlSessionStore) Set(key, value interface{}) error {
+func (st *SessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
-// get value from mysql session
-func (st *MysqlSessionStore) Get(key interface{}) interface{} {
+// Get value from mysql session
+func (st *SessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete value in mysql session
-func (st *MysqlSessionStore) Delete(key interface{}) error {
+// Delete value in mysql session
+func (st *SessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
-// clear all values in mysql session
-func (st *MysqlSessionStore) Flush() error {
+// Flush clear all values in mysql session
+func (st *SessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
-// get session id of this mysql session store
-func (st *MysqlSessionStore) SessionID() string {
+// SessionID get session id of this mysql session store
+func (st *SessionStore) SessionID() string {
return st.sid
}
-// save mysql session values to database.
+// SessionRelease save mysql session values to database.
// must call this method to save values to database.
-func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
+func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
return
}
- st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
+ st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid)
}
-// mysql session provider
-type MysqlProvider struct {
+// Provider mysql session provider
+type Provider struct {
maxlifetime int64
savePath string
}
// connect to mysql
-func (mp *MysqlProvider) connectInit() *sql.DB {
+func (mp *Provider) connectInit() *sql.DB {
db, e := sql.Open("mysql", mp.savePath)
if e != nil {
return nil
@@ -130,22 +133,22 @@ func (mp *MysqlProvider) connectInit() *sql.DB {
return db
}
-// init mysql session.
+// SessionInit init mysql session.
// savepath is the connection string of mysql.
-func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error {
+func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
-// get mysql session by sid
-func (mp *MysqlProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead get mysql session by sid
+func (mp *Provider) SessionRead(sid string) (session.Store, error) {
c := mp.connectInit()
- row := c.QueryRow("select session_data from session where session_key=?", sid)
+ row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
- c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
+ c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
sid, "", time.Now().Unix())
}
var kv map[interface{}]interface{}
@@ -157,34 +160,33 @@ func (mp *MysqlProvider) SessionRead(sid string) (session.SessionStore, error) {
return nil, err
}
}
- rs := &MysqlSessionStore{c: c, sid: sid, values: kv}
+ rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
-// check mysql session exist
-func (mp *MysqlProvider) SessionExist(sid string) bool {
+// SessionExist check mysql session exist
+func (mp *Provider) SessionExist(sid string) bool {
c := mp.connectInit()
defer c.Close()
- row := c.QueryRow("select session_data from session where session_key=?", sid)
+ row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
return false
- } else {
- return true
}
+ return true
}
-// generate new sid for mysql session
-func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate generate new sid for mysql session
+func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
- row := c.QueryRow("select session_data from session where session_key=?", oldsid)
+ row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
- c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
+ c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
}
- c.Exec("update session set `session_key`=? where session_key=?", sid, oldsid)
+ c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
@@ -194,32 +196,32 @@ func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionS
return nil, err
}
}
- rs := &MysqlSessionStore{c: c, sid: sid, values: kv}
+ rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
-// delete mysql session by sid
-func (mp *MysqlProvider) SessionDestroy(sid string) error {
+// SessionDestroy delete mysql session by sid
+func (mp *Provider) SessionDestroy(sid string) error {
c := mp.connectInit()
- c.Exec("DELETE FROM session where session_key=?", sid)
+ c.Exec("DELETE FROM "+TableName+" where session_key=?", sid)
c.Close()
return nil
}
-// delete expired values in mysql session
-func (mp *MysqlProvider) SessionGC() {
+// SessionGC delete expired values in mysql session
+func (mp *Provider) SessionGC() {
c := mp.connectInit()
- c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
+ c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
c.Close()
return
}
-// count values in mysql session
-func (mp *MysqlProvider) SessionAll() int {
+// SessionAll count values in mysql session
+func (mp *Provider) SessionAll() int {
c := mp.connectInit()
defer c.Close()
var total int
- err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
+ err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total)
if err != nil {
return 0
}
diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go
index ac9a1612..73f9c13a 100644
--- a/session/postgres/sess_postgresql.go
+++ b/session/postgres/sess_postgresql.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package postgresql for session provider
+// Package postgres for session provider
//
// depends on github.com/lib/pq:
//
@@ -27,9 +27,9 @@
// session_expiry timestamp NOT NULL,
// CONSTRAINT session_key PRIMARY KEY(session_key)
// );
-
+//
// will be activated with these settings in app.conf:
-
+//
// SessionOn = true
// SessionProvider = postgresql
// SessionSavePath = "user=a password=b dbname=c sslmode=disable"
@@ -48,7 +48,7 @@
// }
//
// more docs: http://beego.me/docs/module/session.md
-package session
+package postgres
import (
"database/sql"
@@ -57,64 +57,63 @@ import (
"time"
"github.com/astaxie/beego/session"
-
+ // import postgresql Driver
_ "github.com/lib/pq"
)
-var postgresqlpder = &PostgresqlProvider{}
+var postgresqlpder = &Provider{}
-// postgresql session store
-type PostgresqlSessionStore struct {
+// SessionStore postgresql session store
+type SessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
-// set value in postgresql session.
+// Set value in postgresql session.
// it is temp value in map.
-func (st *PostgresqlSessionStore) Set(key, value interface{}) error {
+func (st *SessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
-// get value from postgresql session
-func (st *PostgresqlSessionStore) Get(key interface{}) interface{} {
+// Get value from postgresql session
+func (st *SessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete value in postgresql session
-func (st *PostgresqlSessionStore) Delete(key interface{}) error {
+// Delete value in postgresql session
+func (st *SessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
-// clear all values in postgresql session
-func (st *PostgresqlSessionStore) Flush() error {
+// Flush clear all values in postgresql session
+func (st *SessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
-// get session id of this postgresql session store
-func (st *PostgresqlSessionStore) SessionID() string {
+// SessionID get session id of this postgresql session store
+func (st *SessionStore) SessionID() string {
return st.sid
}
-// save postgresql session values to database.
+// SessionRelease save postgresql session values to database.
// must call this method to save values to database.
-func (st *PostgresqlSessionStore) SessionRelease(w http.ResponseWriter) {
+func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
@@ -125,14 +124,14 @@ func (st *PostgresqlSessionStore) SessionRelease(w http.ResponseWriter) {
}
-// postgresql session provider
-type PostgresqlProvider struct {
+// Provider postgresql session provider
+type Provider struct {
maxlifetime int64
savePath string
}
// connect to postgresql
-func (mp *PostgresqlProvider) connectInit() *sql.DB {
+func (mp *Provider) connectInit() *sql.DB {
db, e := sql.Open("postgres", mp.savePath)
if e != nil {
return nil
@@ -140,16 +139,16 @@ func (mp *PostgresqlProvider) connectInit() *sql.DB {
return db
}
-// init postgresql session.
+// SessionInit init postgresql session.
// savepath is the connection string of postgresql.
-func (mp *PostgresqlProvider) SessionInit(maxlifetime int64, savePath string) error {
+func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
-// get postgresql session by sid
-func (mp *PostgresqlProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead get postgresql session by sid
+func (mp *Provider) SessionRead(sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
@@ -174,12 +173,12 @@ func (mp *PostgresqlProvider) SessionRead(sid string) (session.SessionStore, err
return nil, err
}
}
- rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
+ rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
-// check postgresql session exist
-func (mp *PostgresqlProvider) SessionExist(sid string) bool {
+// SessionExist check postgresql session exist
+func (mp *Provider) SessionExist(sid string) bool {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
@@ -188,13 +187,12 @@ func (mp *PostgresqlProvider) SessionExist(sid string) bool {
if err == sql.ErrNoRows {
return false
- } else {
- return true
}
+ return true
}
-// generate new sid for postgresql session
-func (mp *PostgresqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate generate new sid for postgresql session
+func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
var sessiondata []byte
@@ -213,28 +211,28 @@ func (mp *PostgresqlProvider) SessionRegenerate(oldsid, sid string) (session.Ses
return nil, err
}
}
- rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
+ rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
-// delete postgresql session by sid
-func (mp *PostgresqlProvider) SessionDestroy(sid string) error {
+// SessionDestroy delete postgresql session by sid
+func (mp *Provider) SessionDestroy(sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=$1", sid)
c.Close()
return nil
}
-// delete expired values in postgresql session
-func (mp *PostgresqlProvider) SessionGC() {
+// SessionGC delete expired values in postgresql session
+func (mp *Provider) SessionGC() {
c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close()
return
}
-// count values in postgresql session
-func (mp *PostgresqlProvider) SessionAll() int {
+// SessionAll count values in postgresql session
+func (mp *Provider) SessionAll() int {
c := mp.connectInit()
defer c.Close()
var total int
diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go
index 887fb520..c46fa7cd 100644
--- a/session/redis/sess_redis.go
+++ b/session/redis/sess_redis.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package redis for session provider
+// Package redis for session provider
//
// depend on github.com/garyburd/redigo/redis
//
@@ -30,7 +30,7 @@
// }
//
// more docs: http://beego.me/docs/module/session.md
-package session
+package redis
import (
"net/http"
@@ -43,15 +43,13 @@ import (
"github.com/garyburd/redigo/redis"
)
-var redispder = &RedisProvider{}
+var redispder = &Provider{}
-// redis max pool size
-var MAX_POOL_SIZE = 100
+// MaxPoolSize redis max pool size
+var MaxPoolSize = 100
-var redisPool chan redis.Conn
-
-// redis session store
-type RedisSessionStore struct {
+// SessionStore redis session store
+type SessionStore struct {
p *redis.Pool
sid string
lock sync.RWMutex
@@ -59,61 +57,58 @@ type RedisSessionStore struct {
maxlifetime int64
}
-// set value in redis session
-func (rs *RedisSessionStore) Set(key, value interface{}) error {
+// Set value in redis session
+func (rs *SessionStore) Set(key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
-// get value in redis session
-func (rs *RedisSessionStore) Get(key interface{}) interface{} {
+// Get value in redis session
+func (rs *SessionStore) Get(key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
-// delete value in redis session
-func (rs *RedisSessionStore) Delete(key interface{}) error {
+// Delete value in redis session
+func (rs *SessionStore) Delete(key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
-// clear all values in redis session
-func (rs *RedisSessionStore) Flush() error {
+// Flush clear all values in redis session
+func (rs *SessionStore) Flush() error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
-// get redis session id
-func (rs *RedisSessionStore) SessionID() string {
+// SessionID get redis session id
+func (rs *SessionStore) SessionID() string {
return rs.sid
}
-// save session values to redis
-func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
- c := rs.p.Get()
- defer c.Close()
-
+// SessionRelease save session values to redis
+func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
-
+ c := rs.p.Get()
+ defer c.Close()
c.Do("SETEX", rs.sid, rs.maxlifetime, string(b))
}
-// redis session provider
-type RedisProvider struct {
+// Provider redis session provider
+type Provider struct {
maxlifetime int64
savePath string
poolsize int
@@ -122,10 +117,10 @@ type RedisProvider struct {
poollist *redis.Pool
}
-// init redis session
+// SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379,100,astaxie,0
-func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
+func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
@@ -134,18 +129,18 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize <= 0 {
- rp.poolsize = MAX_POOL_SIZE
+ rp.poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
}
} else {
- rp.poolsize = MAX_POOL_SIZE
+ rp.poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
- dbnum, err := strconv.Atoi(configs[1])
+ dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
@@ -176,8 +171,8 @@ func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error {
return rp.poollist.Get().Err()
}
-// read redis session by sid
-func (rp *RedisProvider) SessionRead(sid string) (session.SessionStore, error) {
+// SessionRead read redis session by sid
+func (rp *Provider) SessionRead(sid string) (session.Store, error) {
c := rp.poollist.Get()
defer c.Close()
@@ -192,24 +187,23 @@ func (rp *RedisProvider) SessionRead(sid string) (session.SessionStore, error) {
}
}
- rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
-// check redis session exist by sid
-func (rp *RedisProvider) SessionExist(sid string) bool {
+// SessionExist check redis session exist by sid
+func (rp *Provider) SessionExist(sid string) bool {
c := rp.poollist.Get()
defer c.Close()
if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 {
return false
- } else {
- return true
}
+ return true
}
-// generate new sid for redis session
-func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) {
+// SessionRegenerate generate new sid for redis session
+func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
c := rp.poollist.Get()
defer c.Close()
@@ -234,12 +228,12 @@ func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (session.SessionS
}
}
- rs := &RedisSessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
-// delete redis session by id
-func (rp *RedisProvider) SessionDestroy(sid string) error {
+// SessionDestroy delete redis session by id
+func (rp *Provider) SessionDestroy(sid string) error {
c := rp.poollist.Get()
defer c.Close()
@@ -247,13 +241,13 @@ func (rp *RedisProvider) SessionDestroy(sid string) error {
return nil
}
-// Impelment method, no used.
-func (rp *RedisProvider) SessionGC() {
+// SessionGC Impelment method, no used.
+func (rp *Provider) SessionGC() {
return
}
-// @todo
-func (rp *RedisProvider) SessionAll() int {
+// SessionAll return all activeSession
+func (rp *Provider) SessionAll() int {
return 0
}
diff --git a/session/sess_cookie.go b/session/sess_cookie.go
index 01dc505c..3fefa360 100644
--- a/session/sess_cookie.go
+++ b/session/sess_cookie.go
@@ -25,7 +25,7 @@ import (
var cookiepder = &CookieProvider{}
-// Cookie SessionStore
+// CookieSessionStore Cookie SessionStore
type CookieSessionStore struct {
sid string
values map[interface{}]interface{} // session data
@@ -47,9 +47,8 @@ func (st *CookieSessionStore) Get(key interface{}) interface{} {
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
// Delete value in cookie session
@@ -60,7 +59,7 @@ func (st *CookieSessionStore) Delete(key interface{}) error {
return nil
}
-// Clean all values in cookie session
+// Flush Clean all values in cookie session
func (st *CookieSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
@@ -68,12 +67,12 @@ func (st *CookieSessionStore) Flush() error {
return nil
}
-// Return id of this cookie session
+// SessionID Return id of this cookie session
func (st *CookieSessionStore) SessionID() string {
return st.sid
}
-// Write cookie session to http response cookie
+// SessionRelease Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
str, err := encodeCookie(cookiepder.block,
cookiepder.config.SecurityKey,
@@ -101,14 +100,14 @@ type cookieConfig struct {
Maxage int `json:"maxage"`
}
-// Cookie session provider
+// CookieProvider Cookie session provider
type CookieProvider struct {
maxlifetime int64
config *cookieConfig
block cipher.Block
}
-// Init cookie session provider with max lifetime and config json.
+// SessionInit Init cookie session provider with max lifetime and config json.
// maxlifetime is ignored.
// json config:
// securityKey - hash string
@@ -136,9 +135,9 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error
return nil
}
-// Get SessionStore in cooke.
+// SessionRead Get SessionStore in cooke.
// decode cooke string to map and put into SessionStore with sid.
-func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
+func (pder *CookieProvider) SessionRead(sid string) (Store, error) {
maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
@@ -150,32 +149,32 @@ func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
return rs, nil
}
-// Cookie session is always existed
+// SessionExist Cookie session is always existed
func (pder *CookieProvider) SessionExist(sid string) bool {
return true
}
-// Implement method, no used.
-func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
+// SessionRegenerate Implement method, no used.
+func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
return nil, nil
}
-// Implement method, no used.
+// SessionDestroy Implement method, no used.
func (pder *CookieProvider) SessionDestroy(sid string) error {
return nil
}
-// Implement method, no used.
+// SessionGC Implement method, no used.
func (pder *CookieProvider) SessionGC() {
return
}
-// Implement method, return 0.
+// SessionAll Implement method, return 0.
func (pder *CookieProvider) SessionAll() int {
return 0
}
-// Implement method, no used.
+// SessionUpdate Implement method, no used.
func (pder *CookieProvider) SessionUpdate(sid string) error {
return nil
}
diff --git a/session/sess_cookie_test.go b/session/sess_cookie_test.go
index fe3ac806..209e501c 100644
--- a/session/sess_cookie_test.go
+++ b/session/sess_cookie_test.go
@@ -53,3 +53,44 @@ func TestCookie(t *testing.T) {
}
}
}
+
+func TestDestorySessionCookie(t *testing.T) {
+ config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
+ globalSessions, err := NewManager("cookie", config)
+ if err != nil {
+ t.Fatal("init cookie session err", err)
+ }
+
+ r, _ := http.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+ session, err := globalSessions.SessionStart(w, r)
+ if err != nil {
+ t.Fatal("session start err,", err)
+ }
+
+ // request again ,will get same sesssion id .
+ r1, _ := http.NewRequest("GET", "/", nil)
+ r1.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
+ w = httptest.NewRecorder()
+ newSession, err := globalSessions.SessionStart(w, r1)
+ if err != nil {
+ t.Fatal("session start err,", err)
+ }
+ if newSession.SessionID() != session.SessionID() {
+ t.Fatal("get cookie session id is not the same again.")
+ }
+
+ // After destroy session , will get a new session id .
+ globalSessions.SessionDestroy(w, r1)
+ r2, _ := http.NewRequest("GET", "/", nil)
+ r2.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
+
+ w = httptest.NewRecorder()
+ newSession, err = globalSessions.SessionStart(w, r2)
+ if err != nil {
+ t.Fatal("session start error")
+ }
+ if newSession.SessionID() == session.SessionID() {
+ t.Fatal("after destroy session and reqeust again ,get cookie session id is same.")
+ }
+}
diff --git a/session/sess_file.go b/session/sess_file.go
index b1084acf..9265b030 100644
--- a/session/sess_file.go
+++ b/session/sess_file.go
@@ -32,7 +32,7 @@ var (
gcmaxlifetime int64
)
-// File session store
+// FileSessionStore File session store
type FileSessionStore struct {
sid string
lock sync.RWMutex
@@ -53,9 +53,8 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} {
defer fs.lock.RUnlock()
if v, ok := fs.values[key]; ok {
return v
- } else {
- return nil
}
+ return nil
}
// Delete value in file session by given key
@@ -66,7 +65,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error {
return nil
}
-// Clean all values in file session
+// Flush Clean all values in file session
func (fs *FileSessionStore) Flush() error {
fs.lock.Lock()
defer fs.lock.Unlock()
@@ -74,12 +73,12 @@ func (fs *FileSessionStore) Flush() error {
return nil
}
-// Get file session store id
+// SessionID Get file session store id
func (fs *FileSessionStore) SessionID() string {
return fs.sid
}
-// Write file session to local file with Gob string
+// SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
b, err := EncodeGob(fs.values)
if err != nil {
@@ -100,14 +99,14 @@ func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
f.Close()
}
-// File session provider
+// FileProvider File session provider
type FileProvider struct {
lock sync.RWMutex
maxlifetime int64
savePath string
}
-// Init file session provider.
+// SessionInit Init file session provider.
// savePath sets the session files path.
func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
fp.maxlifetime = maxlifetime
@@ -115,10 +114,10 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
return nil
}
-// Read file session by sid.
+// SessionRead Read file session by sid.
// if file is not exist, create it.
// the file path is generated from sid string.
-func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) {
+func (fp *FileProvider) SessionRead(sid string) (Store, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
@@ -154,7 +153,7 @@ func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) {
return ss, nil
}
-// Check file session exist.
+// SessionExist Check file session exist.
// it checkes the file named from sid exist or not.
func (fp *FileProvider) SessionExist(sid string) bool {
filepder.lock.Lock()
@@ -163,12 +162,11 @@ func (fp *FileProvider) SessionExist(sid string) bool {
_, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
if err == nil {
return true
- } else {
- return false
}
+ return false
}
-// Remove all files in this save path
+// SessionDestroy Remove all files in this save path
func (fp *FileProvider) SessionDestroy(sid string) error {
filepder.lock.Lock()
defer filepder.lock.Unlock()
@@ -176,7 +174,7 @@ func (fp *FileProvider) SessionDestroy(sid string) error {
return nil
}
-// Recycle files in save path
+// SessionGC Recycle files in save path
func (fp *FileProvider) SessionGC() {
filepder.lock.Lock()
defer filepder.lock.Unlock()
@@ -185,7 +183,7 @@ func (fp *FileProvider) SessionGC() {
filepath.Walk(fp.savePath, gcpath)
}
-// Get active file session number.
+// SessionAll Get active file session number.
// it walks save path to count files.
func (fp *FileProvider) SessionAll() int {
a := &activeSession{}
@@ -199,9 +197,9 @@ func (fp *FileProvider) SessionAll() int {
return a.total
}
-// Generate new sid for file session.
+// SessionRegenerate Generate new sid for file session.
// it delete old file and create new file named from new sid.
-func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
+func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
@@ -269,14 +267,14 @@ type activeSession struct {
total int
}
-func (self *activeSession) visit(paths string, f os.FileInfo, err error) error {
+func (as *activeSession) visit(paths string, f os.FileInfo, err error) error {
if err != nil {
return err
}
if f.IsDir() {
return nil
}
- self.total = self.total + 1
+ as.total = as.total + 1
return nil
}
diff --git a/session/sess_mem.go b/session/sess_mem.go
index 627a3246..dd61ef57 100644
--- a/session/sess_mem.go
+++ b/session/sess_mem.go
@@ -1,199 +1,196 @@
-// Copyright 2014 beego Author. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package session
-
-import (
- "container/list"
- "net/http"
- "sync"
- "time"
-)
-
-var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
-
-// memory session store.
-// it saved sessions in a map in memory.
-type MemSessionStore struct {
- sid string //session id
- timeAccessed time.Time //last access time
- value map[interface{}]interface{} //session store
- lock sync.RWMutex
-}
-
-// set value to memory session
-func (st *MemSessionStore) Set(key, value interface{}) error {
- st.lock.Lock()
- defer st.lock.Unlock()
- st.value[key] = value
- return nil
-}
-
-// get value from memory session by key
-func (st *MemSessionStore) Get(key interface{}) interface{} {
- st.lock.RLock()
- defer st.lock.RUnlock()
- if v, ok := st.value[key]; ok {
- return v
- } else {
- return nil
- }
-}
-
-// delete in memory session by key
-func (st *MemSessionStore) Delete(key interface{}) error {
- st.lock.Lock()
- defer st.lock.Unlock()
- delete(st.value, key)
- return nil
-}
-
-// clear all values in memory session
-func (st *MemSessionStore) Flush() error {
- st.lock.Lock()
- defer st.lock.Unlock()
- st.value = make(map[interface{}]interface{})
- return nil
-}
-
-// get this id of memory session store
-func (st *MemSessionStore) SessionID() string {
- return st.sid
-}
-
-// Implement method, no used.
-func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
-}
-
-type MemProvider struct {
- lock sync.RWMutex // locker
- sessions map[string]*list.Element // map in memory
- list *list.List // for gc
- maxlifetime int64
- savePath string
-}
-
-// init memory session
-func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
- pder.maxlifetime = maxlifetime
- pder.savePath = savePath
- return nil
-}
-
-// get memory session store by sid
-func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
- pder.lock.RLock()
- if element, ok := pder.sessions[sid]; ok {
- go pder.SessionUpdate(sid)
- pder.lock.RUnlock()
- return element.Value.(*MemSessionStore), nil
- } else {
- pder.lock.RUnlock()
- pder.lock.Lock()
- newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
- element := pder.list.PushBack(newsess)
- pder.sessions[sid] = element
- pder.lock.Unlock()
- return newsess, nil
- }
-}
-
-// check session store exist in memory session by sid
-func (pder *MemProvider) SessionExist(sid string) bool {
- pder.lock.RLock()
- defer pder.lock.RUnlock()
- if _, ok := pder.sessions[sid]; ok {
- return true
- } else {
- return false
- }
-}
-
-// generate new sid for session store in memory session
-func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
- pder.lock.RLock()
- if element, ok := pder.sessions[oldsid]; ok {
- go pder.SessionUpdate(oldsid)
- pder.lock.RUnlock()
- pder.lock.Lock()
- element.Value.(*MemSessionStore).sid = sid
- pder.sessions[sid] = element
- delete(pder.sessions, oldsid)
- pder.lock.Unlock()
- return element.Value.(*MemSessionStore), nil
- } else {
- pder.lock.RUnlock()
- pder.lock.Lock()
- newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
- element := pder.list.PushBack(newsess)
- pder.sessions[sid] = element
- pder.lock.Unlock()
- return newsess, nil
- }
-}
-
-// delete session store in memory session by id
-func (pder *MemProvider) SessionDestroy(sid string) error {
- pder.lock.Lock()
- defer pder.lock.Unlock()
- if element, ok := pder.sessions[sid]; ok {
- delete(pder.sessions, sid)
- pder.list.Remove(element)
- return nil
- }
- return nil
-}
-
-// clean expired session stores in memory session
-func (pder *MemProvider) SessionGC() {
- pder.lock.RLock()
- for {
- element := pder.list.Back()
- if element == nil {
- break
- }
- if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() {
- pder.lock.RUnlock()
- pder.lock.Lock()
- pder.list.Remove(element)
- delete(pder.sessions, element.Value.(*MemSessionStore).sid)
- pder.lock.Unlock()
- pder.lock.RLock()
- } else {
- break
- }
- }
- pder.lock.RUnlock()
-}
-
-// get count number of memory session
-func (pder *MemProvider) SessionAll() int {
- return pder.list.Len()
-}
-
-// expand time of session store by id in memory session
-func (pder *MemProvider) SessionUpdate(sid string) error {
- pder.lock.Lock()
- defer pder.lock.Unlock()
- if element, ok := pder.sessions[sid]; ok {
- element.Value.(*MemSessionStore).timeAccessed = time.Now()
- pder.list.MoveToFront(element)
- return nil
- }
- return nil
-}
-
-func init() {
- Register("memory", mempder)
-}
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package session
+
+import (
+ "container/list"
+ "net/http"
+ "sync"
+ "time"
+)
+
+var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
+
+// MemSessionStore memory session store.
+// it saved sessions in a map in memory.
+type MemSessionStore struct {
+ sid string //session id
+ timeAccessed time.Time //last access time
+ value map[interface{}]interface{} //session store
+ lock sync.RWMutex
+}
+
+// Set value to memory session
+func (st *MemSessionStore) Set(key, value interface{}) error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ st.value[key] = value
+ return nil
+}
+
+// Get value from memory session by key
+func (st *MemSessionStore) Get(key interface{}) interface{} {
+ st.lock.RLock()
+ defer st.lock.RUnlock()
+ if v, ok := st.value[key]; ok {
+ return v
+ }
+ return nil
+}
+
+// Delete in memory session by key
+func (st *MemSessionStore) Delete(key interface{}) error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ delete(st.value, key)
+ return nil
+}
+
+// Flush clear all values in memory session
+func (st *MemSessionStore) Flush() error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ st.value = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID get this id of memory session store
+func (st *MemSessionStore) SessionID() string {
+ return st.sid
+}
+
+// SessionRelease Implement method, no used.
+func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
+}
+
+// MemProvider Implement the provider interface
+type MemProvider struct {
+ lock sync.RWMutex // locker
+ sessions map[string]*list.Element // map in memory
+ list *list.List // for gc
+ maxlifetime int64
+ savePath string
+}
+
+// SessionInit init memory session
+func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
+ pder.maxlifetime = maxlifetime
+ pder.savePath = savePath
+ return nil
+}
+
+// SessionRead get memory session store by sid
+func (pder *MemProvider) SessionRead(sid string) (Store, error) {
+ pder.lock.RLock()
+ if element, ok := pder.sessions[sid]; ok {
+ go pder.SessionUpdate(sid)
+ pder.lock.RUnlock()
+ return element.Value.(*MemSessionStore), nil
+ }
+ pder.lock.RUnlock()
+ pder.lock.Lock()
+ newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
+ element := pder.list.PushBack(newsess)
+ pder.sessions[sid] = element
+ pder.lock.Unlock()
+ return newsess, nil
+}
+
+// SessionExist check session store exist in memory session by sid
+func (pder *MemProvider) SessionExist(sid string) bool {
+ pder.lock.RLock()
+ defer pder.lock.RUnlock()
+ if _, ok := pder.sessions[sid]; ok {
+ return true
+ }
+ return false
+}
+
+// SessionRegenerate generate new sid for session store in memory session
+func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
+ pder.lock.RLock()
+ if element, ok := pder.sessions[oldsid]; ok {
+ go pder.SessionUpdate(oldsid)
+ pder.lock.RUnlock()
+ pder.lock.Lock()
+ element.Value.(*MemSessionStore).sid = sid
+ pder.sessions[sid] = element
+ delete(pder.sessions, oldsid)
+ pder.lock.Unlock()
+ return element.Value.(*MemSessionStore), nil
+ }
+ pder.lock.RUnlock()
+ pder.lock.Lock()
+ newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
+ element := pder.list.PushBack(newsess)
+ pder.sessions[sid] = element
+ pder.lock.Unlock()
+ return newsess, nil
+}
+
+// SessionDestroy delete session store in memory session by id
+func (pder *MemProvider) SessionDestroy(sid string) error {
+ pder.lock.Lock()
+ defer pder.lock.Unlock()
+ if element, ok := pder.sessions[sid]; ok {
+ delete(pder.sessions, sid)
+ pder.list.Remove(element)
+ return nil
+ }
+ return nil
+}
+
+// SessionGC clean expired session stores in memory session
+func (pder *MemProvider) SessionGC() {
+ pder.lock.RLock()
+ for {
+ element := pder.list.Back()
+ if element == nil {
+ break
+ }
+ if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() {
+ pder.lock.RUnlock()
+ pder.lock.Lock()
+ pder.list.Remove(element)
+ delete(pder.sessions, element.Value.(*MemSessionStore).sid)
+ pder.lock.Unlock()
+ pder.lock.RLock()
+ } else {
+ break
+ }
+ }
+ pder.lock.RUnlock()
+}
+
+// SessionAll get count number of memory session
+func (pder *MemProvider) SessionAll() int {
+ return pder.list.Len()
+}
+
+// SessionUpdate expand time of session store by id in memory session
+func (pder *MemProvider) SessionUpdate(sid string) error {
+ pder.lock.Lock()
+ defer pder.lock.Unlock()
+ if element, ok := pder.sessions[sid]; ok {
+ element.Value.(*MemSessionStore).timeAccessed = time.Now()
+ pder.list.MoveToFront(element)
+ return nil
+ }
+ return nil
+}
+
+func init() {
+ Register("memory", mempder)
+}
diff --git a/session/sess_utils.go b/session/sess_utils.go
index 9ae74528..d7db5ba8 100644
--- a/session/sess_utils.go
+++ b/session/sess_utils.go
@@ -43,6 +43,7 @@ func init() {
gob.Register(map[int]int64{})
}
+// EncodeGob encode the obj to gob
func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) {
for _, v := range obj {
gob.Register(v)
@@ -56,6 +57,7 @@ func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) {
return buf.Bytes(), nil
}
+// DecodeGob decode data to map
func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
@@ -178,11 +180,11 @@ func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime
return nil, err
}
// 5. DecodeGob.
- if dst, err := DecodeGob(b); err != nil {
+ dst, err := DecodeGob(b)
+ if err != nil {
return nil, err
- } else {
- return dst, nil
}
+ return dst, nil
}
// Encoding -------------------------------------------------------------------
diff --git a/session/session.go b/session/session.go
index 3cbd2b05..39d475fc 100644
--- a/session/session.go
+++ b/session/session.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// package session provider
+// Package session provider
//
// Usage:
// import(
@@ -20,7 +20,7 @@
// )
//
// func init() {
-// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "sessionIDHashFunc": "sha1", "sessionIDHashKey": "", "cookieLifeTime": 3600, "providerConfig": ""}`)
+// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`)
// go globalSessions.GC()
// }
//
@@ -37,8 +37,8 @@ import (
"time"
)
-// SessionStore contains all data for one session process with specific id.
-type SessionStore interface {
+// Store contains all data for one session process with specific id.
+type Store interface {
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
@@ -51,9 +51,9 @@ type SessionStore interface {
// it can operate a SessionStore by its id.
type Provider interface {
SessionInit(gclifetime int64, config string) error
- SessionRead(sid string) (SessionStore, error)
+ SessionRead(sid string) (Store, error)
SessionExist(sid string) bool
- SessionRegenerate(oldsid, sid string) (SessionStore, error)
+ SessionRegenerate(oldsid, sid string) (Store, error)
SessionDestroy(sid string) error
SessionAll() int //get all active session
SessionGC()
@@ -83,7 +83,7 @@ type managerConfig struct {
CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"`
- SessionIdLength int64 `json:"sessionIdLength"`
+ SessionIDLength int64 `json:"sessionIDLength"`
}
// Manager contains Provider and its configuration.
@@ -92,7 +92,7 @@ type Manager struct {
config *managerConfig
}
-// Create new Manager with provider name and json config string.
+// NewManager Create new Manager with provider name and json config string.
// provider name:
// 1. cookie
// 2. file
@@ -123,8 +123,8 @@ func NewManager(provideName, config string) (*Manager, error) {
return nil, err
}
- if cf.SessionIdLength == 0 {
- cf.SessionIdLength = 16
+ if cf.SessionIDLength == 0 {
+ cf.SessionIDLength = 16
}
return &Manager{
@@ -133,105 +133,115 @@ func NewManager(provideName, config string) (*Manager, error) {
}, nil
}
-// Start session. generate or read the session id from http request.
-// if session id exists, return SessionStore with this id.
-func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore, err error) {
+// getSid retrieves session identifier from HTTP Request.
+// First try to retrieve id by reading from cookie, session cookie name is configurable,
+// if not exist, then retrieve id from querying parameters.
+//
+// error is not nil when there is anything wrong.
+// sid is empty when need to generate a new session id
+// otherwise return an valid session id.
+func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName)
- if errs != nil || cookie.Value == "" {
- sid, errs := manager.sessionId(r)
+ if errs != nil || cookie.Value == "" || cookie.MaxAge < 0 {
+ errs := r.ParseForm()
if errs != nil {
- return nil, errs
- }
- session, err = manager.provider.SessionRead(sid)
- cookie = &http.Cookie{Name: manager.config.CookieName,
- Value: url.QueryEscape(sid),
- Path: "/",
- HttpOnly: true,
- Secure: manager.config.Secure,
- Domain: manager.config.Domain}
- if manager.config.CookieLifeTime >= 0 {
- cookie.MaxAge = manager.config.CookieLifeTime
- }
- if manager.config.EnableSetCookie {
- http.SetCookie(w, cookie)
- }
- r.AddCookie(cookie)
- } else {
- sid, errs := url.QueryUnescape(cookie.Value)
- if errs != nil {
- return nil, errs
- }
- if manager.provider.SessionExist(sid) {
- session, err = manager.provider.SessionRead(sid)
- } else {
- sid, err = manager.sessionId(r)
- if err != nil {
- return nil, err
- }
- session, err = manager.provider.SessionRead(sid)
- cookie = &http.Cookie{Name: manager.config.CookieName,
- Value: url.QueryEscape(sid),
- Path: "/",
- HttpOnly: true,
- Secure: manager.config.Secure,
- Domain: manager.config.Domain}
- if manager.config.CookieLifeTime >= 0 {
- cookie.MaxAge = manager.config.CookieLifeTime
- }
- if manager.config.EnableSetCookie {
- http.SetCookie(w, cookie)
- }
- r.AddCookie(cookie)
+ return "", errs
}
+
+ sid := r.FormValue(manager.config.CookieName)
+ return sid, nil
}
+
+ // HTTP Request contains cookie for sessionid info.
+ return url.QueryUnescape(cookie.Value)
+}
+
+// SessionStart generate or read the session id from http request.
+// if session id exists, return SessionStore with this id.
+func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) {
+ sid, errs := manager.getSid(r)
+ if errs != nil {
+ return nil, errs
+ }
+
+ if sid != "" && manager.provider.SessionExist(sid) {
+ return manager.provider.SessionRead(sid)
+ }
+
+ // Generate a new session
+ sid, errs = manager.sessionID()
+ if errs != nil {
+ return nil, errs
+ }
+
+ session, err = manager.provider.SessionRead(sid)
+ cookie := &http.Cookie{
+ Name: manager.config.CookieName,
+ Value: url.QueryEscape(sid),
+ Path: "/",
+ HttpOnly: true,
+ Secure: manager.isSecure(r),
+ Domain: manager.config.Domain,
+ }
+ if manager.config.CookieLifeTime > 0 {
+ cookie.MaxAge = manager.config.CookieLifeTime
+ cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
+ }
+ if manager.config.EnableSetCookie {
+ http.SetCookie(w, cookie)
+ }
+ r.AddCookie(cookie)
+
return
}
-// Destroy session by its id in http request cookie.
+// SessionDestroy Destroy session by its id in http request cookie.
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
return
- } else {
- manager.provider.SessionDestroy(cookie.Value)
+ }
+ manager.provider.SessionDestroy(cookie.Value)
+ if manager.config.EnableSetCookie {
expiration := time.Now()
- cookie := http.Cookie{Name: manager.config.CookieName,
+ cookie = &http.Cookie{Name: manager.config.CookieName,
Path: "/",
HttpOnly: true,
Expires: expiration,
MaxAge: -1}
- http.SetCookie(w, &cookie)
+
+ http.SetCookie(w, cookie)
}
}
-// Get SessionStore by its id.
-func (manager *Manager) GetSessionStore(sid string) (sessions SessionStore, err error) {
+// GetSessionStore Get SessionStore by its id.
+func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) {
sessions, err = manager.provider.SessionRead(sid)
return
}
-// Start session gc process.
+// GC Start session gc process.
// it can do gc in times after gc lifetime.
func (manager *Manager) GC() {
manager.provider.SessionGC()
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
}
-// Regenerate a session id for this SessionStore who's id is saving in http request.
-func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
- sid, err := manager.sessionId(r)
+// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request.
+func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) {
+ sid, err := manager.sessionID()
if err != nil {
return
}
cookie, err := r.Cookie(manager.config.CookieName)
- if err != nil && cookie.Value == "" {
+ if err != nil || cookie.Value == "" {
//delete old cookie
session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: true,
- Secure: manager.config.Secure,
+ Secure: manager.isSecure(r),
Domain: manager.config.Domain,
}
} else {
@@ -241,29 +251,46 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
cookie.HttpOnly = true
cookie.Path = "/"
}
- if manager.config.CookieLifeTime >= 0 {
+ if manager.config.CookieLifeTime > 0 {
cookie.MaxAge = manager.config.CookieLifeTime
+ cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
+ }
+ if manager.config.EnableSetCookie {
+ http.SetCookie(w, cookie)
}
- http.SetCookie(w, cookie)
r.AddCookie(cookie)
return
}
-// Get all active sessions count number.
+// GetActiveSession Get all active sessions count number.
func (manager *Manager) GetActiveSession() int {
return manager.provider.SessionAll()
}
-// Set cookie with https.
+// SetSecure Set cookie with https.
func (manager *Manager) SetSecure(secure bool) {
manager.config.Secure = secure
}
-func (manager *Manager) sessionId(r *http.Request) (string, error) {
- b := make([]byte, manager.config.SessionIdLength)
+func (manager *Manager) sessionID() (string, error) {
+ b := make([]byte, manager.config.SessionIDLength)
n, err := rand.Read(b)
if n != len(b) || err != nil {
return "", fmt.Errorf("Could not successfully read from the system CSPRNG.")
}
return hex.EncodeToString(b), nil
}
+
+// Set cookie with https.
+func (manager *Manager) isSecure(req *http.Request) bool {
+ if !manager.config.Secure {
+ return false
+ }
+ if req.URL.Scheme != "" {
+ return req.URL.Scheme == "https"
+ }
+ if req.TLS == nil {
+ return false
+ }
+ return true
+}
diff --git a/staticfile.go b/staticfile.go
index 5ab853a3..9534ce91 100644
--- a/staticfile.go
+++ b/staticfile.go
@@ -15,107 +15,181 @@
package beego
import (
+ "bytes"
+ "errors"
"net/http"
"os"
"path"
+ "path/filepath"
"strconv"
"strings"
+ "sync"
+ "time"
"github.com/astaxie/beego/context"
- "github.com/astaxie/beego/utils"
)
+var errNotStaticRequest = errors.New("request not a static file request")
+
func serverStaticRouter(ctx *context.Context) {
if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" {
return
}
- requestPath := path.Clean(ctx.Input.Request.URL.Path)
- i := 0
- for prefix, staticDir := range StaticDir {
+
+ forbidden, filePath, fileInfo, err := lookupFile(ctx)
+ if err == errNotStaticRequest {
+ return
+ }
+
+ if forbidden {
+ exception("403", ctx)
+ return
+ }
+
+ if filePath == "" || fileInfo == nil {
+ if BConfig.RunMode == DEV {
+ Warn("Can't find/open the file:", filePath, err)
+ }
+ http.NotFound(ctx.ResponseWriter, ctx.Request)
+ return
+ }
+ if fileInfo.IsDir() {
+ //serveFile will list dir
+ http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
+ return
+ }
+
+ var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath)
+ var acceptEncoding string
+ if enableCompress {
+ acceptEncoding = context.ParseEncoding(ctx.Request)
+ }
+ b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding)
+ if err != nil {
+ if BConfig.RunMode == DEV {
+ Warn("Can't compress the file:", filePath, err)
+ }
+ http.NotFound(ctx.ResponseWriter, ctx.Request)
+ return
+ }
+
+ if b {
+ ctx.Output.Header("Content-Encoding", n)
+ } else {
+ ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10))
+ }
+
+ http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, sch)
+ return
+
+}
+
+type serveContentHolder struct {
+ *bytes.Reader
+ modTime time.Time
+ size int64
+ encoding string
+}
+
+var (
+ staticFileMap = make(map[string]*serveContentHolder)
+ mapLock sync.Mutex
+)
+
+func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, error) {
+ mapKey := acceptEncoding + ":" + filePath
+ mapFile, _ := staticFileMap[mapKey]
+ if isOk(mapFile, fi) {
+ return mapFile.encoding != "", mapFile.encoding, mapFile, nil
+ }
+ mapLock.Lock()
+ defer mapLock.Unlock()
+ if mapFile, _ = staticFileMap[mapKey]; !isOk(mapFile, fi) {
+ file, err := os.Open(filePath)
+ if err != nil {
+ return false, "", nil, err
+ }
+ defer file.Close()
+ var bufferWriter bytes.Buffer
+ _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file)
+ if err != nil {
+ return false, "", nil, err
+ }
+ mapFile = &serveContentHolder{Reader: bytes.NewReader(bufferWriter.Bytes()), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n}
+ staticFileMap[mapKey] = mapFile
+ }
+
+ return mapFile.encoding != "", mapFile.encoding, mapFile, nil
+}
+
+func isOk(s *serveContentHolder, fi os.FileInfo) bool {
+ if s == nil {
+ return false
+ }
+ return s.modTime == fi.ModTime() && s.size == fi.Size()
+}
+
+// isStaticCompress detect static files
+func isStaticCompress(filePath string) bool {
+ for _, statExtension := range BConfig.WebConfig.StaticExtensionsToGzip {
+ if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) {
+ return true
+ }
+ }
+ return false
+}
+
+// searchFile search the file by url path
+// if none the static file prefix matches ,return notStaticRequestErr
+func searchFile(ctx *context.Context) (string, os.FileInfo, error) {
+ requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path))
+ // special processing : favicon.ico/robots.txt can be in any static dir
+ if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
+ file := path.Join(".", requestPath)
+ if fi, _ := os.Stat(file); fi != nil {
+ return file, fi, nil
+ }
+ for _, staticDir := range BConfig.WebConfig.StaticDir {
+ filePath := path.Join(staticDir, requestPath)
+ if fi, _ := os.Stat(filePath); fi != nil {
+ return filePath, fi, nil
+ }
+ }
+ return "", nil, errors.New(requestPath + " file not find")
+ }
+
+ for prefix, staticDir := range BConfig.WebConfig.StaticDir {
if len(prefix) == 0 {
continue
}
- if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
- file := path.Join(staticDir, requestPath)
- if utils.FileExists(file) {
- http.ServeFile(ctx.ResponseWriter, ctx.Request, file)
- return
- } else {
- i++
- if i == len(StaticDir) {
- http.NotFound(ctx.ResponseWriter, ctx.Request)
- return
- } else {
- continue
- }
- }
+ if !strings.Contains(requestPath, prefix) {
+ continue
}
- if strings.HasPrefix(requestPath, prefix) {
- if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' {
- continue
- }
- file := path.Join(staticDir, requestPath[len(prefix):])
- finfo, err := os.Stat(file)
- if err != nil {
- if RunMode == "dev" {
- Warn(err)
- }
- http.NotFound(ctx.ResponseWriter, ctx.Request)
- return
- }
- //if the request is dir and DirectoryIndex is false then
- if finfo.IsDir() {
- if !DirectoryIndex {
- exception("403", ctx)
- return
- } else if ctx.Input.Request.URL.Path[len(ctx.Input.Request.URL.Path)-1] != '/' {
- http.Redirect(ctx.ResponseWriter, ctx.Request, ctx.Input.Request.URL.Path+"/", 302)
- return
- }
- } else if strings.HasSuffix(requestPath, "/index.html") {
- file := path.Join(staticDir, requestPath)
- if utils.FileExists(file) {
- http.ServeFile(ctx.ResponseWriter, ctx.Request, file)
- return
- }
- }
-
- //This block obtained from (https://github.com/smithfox/beego) - it should probably get merged into astaxie/beego after a pull request
- isStaticFileToCompress := false
- if StaticExtensionsToGzip != nil && len(StaticExtensionsToGzip) > 0 {
- for _, statExtension := range StaticExtensionsToGzip {
- if strings.HasSuffix(strings.ToLower(file), strings.ToLower(statExtension)) {
- isStaticFileToCompress = true
- break
- }
- }
- }
-
- if isStaticFileToCompress {
- var contentEncoding string
- if EnableGzip {
- contentEncoding = getAcceptEncodingZip(ctx.Request)
- }
-
- memzipfile, err := openMemZipFile(file, contentEncoding)
- if err != nil {
- return
- }
-
- if contentEncoding == "gzip" {
- ctx.Output.Header("Content-Encoding", "gzip")
- } else if contentEncoding == "deflate" {
- ctx.Output.Header("Content-Encoding", "deflate")
- } else {
- ctx.Output.Header("Content-Length", strconv.FormatInt(finfo.Size(), 10))
- }
-
- http.ServeContent(ctx.ResponseWriter, ctx.Request, file, finfo.ModTime(), memzipfile)
-
- } else {
- http.ServeFile(ctx.ResponseWriter, ctx.Request, file)
- }
- return
+ if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' {
+ continue
+ }
+ filePath := path.Join(staticDir, requestPath[len(prefix):])
+ if fi, err := os.Stat(filePath); fi != nil {
+ return filePath, fi, err
}
}
+ return "", nil, errNotStaticRequest
+}
+
+// lookupFile find the file to serve
+// if the file is dir ,search the index.html as default file( MUST NOT A DIR also)
+// if the index.html not exist or is a dir, give a forbidden response depending on DirectoryIndex
+func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) {
+ fp, fi, err := searchFile(ctx)
+ if fp == "" || fi == nil {
+ return false, "", nil, err
+ }
+ if !fi.IsDir() {
+ return false, fp, fi, err
+ }
+ ifp := filepath.Join(fp, "index.html")
+ if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() {
+ return false, ifp, ifi, err
+ }
+ return !BConfig.WebConfig.DirectoryIndex, fp, fi, err
}
diff --git a/staticfile_test.go b/staticfile_test.go
new file mode 100644
index 00000000..d3333570
--- /dev/null
+++ b/staticfile_test.go
@@ -0,0 +1,71 @@
+package beego
+
+import (
+ "bytes"
+ "compress/gzip"
+ "compress/zlib"
+ "io"
+ "io/ioutil"
+ "os"
+ "testing"
+)
+
+const licenseFile = "./LICENSE"
+
+func testOpenFile(encoding string, content []byte, t *testing.T) {
+ fi, _ := os.Stat(licenseFile)
+ b, n, sch, err := openFile(licenseFile, fi, encoding)
+ if err != nil {
+ t.Log(err)
+ t.Fail()
+ }
+
+ t.Log("open static file encoding "+n, b)
+
+ assetOpenFileAndContent(sch, content, t)
+}
+func TestOpenStaticFile_1(t *testing.T) {
+ file, _ := os.Open(licenseFile)
+ content, _ := ioutil.ReadAll(file)
+ testOpenFile("", content, t)
+}
+
+func TestOpenStaticFileGzip_1(t *testing.T) {
+ file, _ := os.Open(licenseFile)
+ var zipBuf bytes.Buffer
+ fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression)
+ io.Copy(fileWriter, file)
+ fileWriter.Close()
+ content, _ := ioutil.ReadAll(&zipBuf)
+
+ testOpenFile("gzip", content, t)
+}
+func TestOpenStaticFileDeflate_1(t *testing.T) {
+ file, _ := os.Open(licenseFile)
+ var zipBuf bytes.Buffer
+ fileWriter, _ := zlib.NewWriterLevel(&zipBuf, zlib.BestCompression)
+ io.Copy(fileWriter, file)
+ fileWriter.Close()
+ content, _ := ioutil.ReadAll(&zipBuf)
+
+ testOpenFile("deflate", content, t)
+}
+
+func assetOpenFileAndContent(sch *serveContentHolder, content []byte, t *testing.T) {
+ t.Log(sch.size, len(content))
+ if sch.size != int64(len(content)) {
+ t.Log("static content file size not same")
+ t.Fail()
+ }
+ bs, _ := ioutil.ReadAll(sch)
+ for i, v := range content {
+ if v != bs[i] {
+ t.Log("content not same")
+ t.Fail()
+ }
+ }
+ if len(staticFileMap) == 0 {
+ t.Log("men map is empty")
+ t.Fail()
+ }
+}
diff --git a/swagger/docsSpec.go b/swagger/docs_spec.go
similarity index 72%
rename from swagger/docsSpec.go
rename to swagger/docs_spec.go
index 6f8fe1db..680324dc 100644
--- a/swagger/docsSpec.go
+++ b/swagger/docs_spec.go
@@ -12,53 +12,59 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// swagger struct definition
+// Package swagger struct definition
package swagger
+// SwaggerVersion show the current swagger version
const SwaggerVersion = "1.2"
+// ResourceListing list the resource
type ResourceListing struct {
- ApiVersion string `json:"apiVersion"`
+ APIVersion string `json:"apiVersion"`
SwaggerVersion string `json:"swaggerVersion"` // e.g 1.2
// BasePath string `json:"basePath"` obsolete in 1.1
- Apis []ApiRef `json:"apis"`
- Infos Infomation `json:"info"`
+ APIs []APIRef `json:"apis"`
+ Info Information `json:"info"`
}
-type ApiRef struct {
+// APIRef description the api path and description
+type APIRef struct {
Path string `json:"path"` // relative or absolute, must start with /
Description string `json:"description"`
}
-type Infomation struct {
+// Information show the API Information
+type Information struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Contact string `json:"contact,omitempty"`
- TermsOfServiceUrl string `json:"termsOfServiceUrl,omitempty"`
+ TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"`
License string `json:"license,omitempty"`
- LicenseUrl string `json:"licenseUrl,omitempty"`
+ LicenseURL string `json:"licenseUrl,omitempty"`
}
-// https://github.com/wordnik/swagger-core/blob/scala_2.10-1.3-RC3/schemas/api-declaration-schema.json
-type ApiDeclaration struct {
- ApiVersion string `json:"apiVersion"`
+// APIDeclaration see https://github.com/wordnik/swagger-core/blob/scala_2.10-1.3-RC3/schemas/api-declaration-schema.json
+type APIDeclaration struct {
+ APIVersion string `json:"apiVersion"`
SwaggerVersion string `json:"swaggerVersion"`
BasePath string `json:"basePath"`
ResourcePath string `json:"resourcePath"` // must start with /
Consumes []string `json:"consumes,omitempty"`
Produces []string `json:"produces,omitempty"`
- Apis []Api `json:"apis,omitempty"`
+ APIs []API `json:"apis,omitempty"`
Models map[string]Model `json:"models,omitempty"`
}
-type Api struct {
+// API show tha API struct
+type API struct {
Path string `json:"path"` // relative or absolute, must start with /
Description string `json:"description"`
Operations []Operation `json:"operations,omitempty"`
}
+// Operation desc the Operation
type Operation struct {
- HttpMethod string `json:"httpMethod"`
+ HTTPMethod string `json:"httpMethod"`
Nickname string `json:"nickname"`
Type string `json:"type"` // in 1.1 = DataType
// ResponseClass string `json:"responseClass"` obsolete in 1.2
@@ -72,15 +78,18 @@ type Operation struct {
Protocols []Protocol `json:"protocols,omitempty"`
}
+// Protocol support which Protocol
type Protocol struct {
}
+// ResponseMessage Show the
type ResponseMessage struct {
Code int `json:"code"`
Message string `json:"message"`
ResponseModel string `json:"responseModel"`
}
+// Parameter desc the request parameters
type Parameter struct {
ParamType string `json:"paramType"` // path,query,body,header,form
Name string `json:"name"`
@@ -94,17 +103,20 @@ type Parameter struct {
Maximum int `json:"maximum"`
}
+// ErrorResponse desc response
type ErrorResponse struct {
Code int `json:"code"`
Reason string `json:"reason"`
}
+// Model define the data model
type Model struct {
- Id string `json:"id"`
+ ID string `json:"id"`
Required []string `json:"required,omitempty"`
Properties map[string]ModelProperty `json:"properties"`
}
+// ModelProperty define the properties
type ModelProperty struct {
Type string `json:"type"`
Description string `json:"description"`
@@ -112,20 +124,20 @@ type ModelProperty struct {
Format string `json:"format"`
}
-// https://github.com/wordnik/swagger-core/wiki/authorizations
+// Authorization see https://github.com/wordnik/swagger-core/wiki/authorizations
type Authorization struct {
LocalOAuth OAuth `json:"local-oauth"`
- ApiKey ApiKey `json:"apiKey"`
+ APIKey APIKey `json:"apiKey"`
}
-// https://github.com/wordnik/swagger-core/wiki/authorizations
+// OAuth see https://github.com/wordnik/swagger-core/wiki/authorizations
type OAuth struct {
Type string `json:"type"` // e.g. oauth2
Scopes []string `json:"scopes"` // e.g. PUBLIC
GrantTypes map[string]GrantType `json:"grantTypes"`
}
-// https://github.com/wordnik/swagger-core/wiki/authorizations
+// GrantType see https://github.com/wordnik/swagger-core/wiki/authorizations
type GrantType struct {
LoginEndpoint Endpoint `json:"loginEndpoint"`
TokenName string `json:"tokenName"` // e.g. access_code
@@ -133,16 +145,16 @@ type GrantType struct {
TokenEndpoint Endpoint `json:"tokenEndpoint"`
}
-// https://github.com/wordnik/swagger-core/wiki/authorizations
+// Endpoint see https://github.com/wordnik/swagger-core/wiki/authorizations
type Endpoint struct {
- Url string `json:"url"`
- ClientIdName string `json:"clientIdName"`
+ URL string `json:"url"`
+ ClientIDName string `json:"clientIdName"`
ClientSecretName string `json:"clientSecretName"`
TokenName string `json:"tokenName"`
}
-// https://github.com/wordnik/swagger-core/wiki/authorizations
-type ApiKey struct {
+// APIKey see https://github.com/wordnik/swagger-core/wiki/authorizations
+type APIKey struct {
Type string `json:"type"` // e.g. apiKey
PassAs string `json:"passAs"` // e.g. header
}
diff --git a/template.go b/template.go
index 64b1939e..363c6754 100644
--- a/template.go
+++ b/template.go
@@ -28,17 +28,14 @@ import (
)
var (
- beegoTplFuncMap template.FuncMap
- // beego template caching map and supported template file extensions.
- BeeTemplates map[string]*template.Template
- BeeTemplateExt []string
+ beegoTplFuncMap = make(template.FuncMap)
+ // BeeTemplates caching map and supported template file extensions.
+ BeeTemplates = make(map[string]*template.Template)
+ // BeeTemplateExt stores the template extension which will build
+ BeeTemplateExt = []string{"tpl", "html"}
)
func init() {
- BeeTemplates = make(map[string]*template.Template)
- beegoTplFuncMap = make(template.FuncMap)
- BeeTemplateExt = make([]string, 0)
- BeeTemplateExt = append(BeeTemplateExt, "tpl", "html")
beegoTplFuncMap["dateformat"] = DateFormat
beegoTplFuncMap["date"] = Date
beegoTplFuncMap["compare"] = Compare
@@ -46,14 +43,15 @@ func init() {
beegoTplFuncMap["not_nil"] = NotNil
beegoTplFuncMap["not_null"] = NotNil
beegoTplFuncMap["substr"] = Substr
- beegoTplFuncMap["html2str"] = Html2str
+ beegoTplFuncMap["html2str"] = HTML2str
beegoTplFuncMap["str2html"] = Str2html
beegoTplFuncMap["htmlquote"] = Htmlquote
beegoTplFuncMap["htmlunquote"] = Htmlunquote
beegoTplFuncMap["renderform"] = RenderForm
beegoTplFuncMap["assets_js"] = AssetsJs
- beegoTplFuncMap["assets_css"] = AssetsCss
- beegoTplFuncMap["config"] = Config
+ beegoTplFuncMap["assets_css"] = AssetsCSS
+ beegoTplFuncMap["config"] = GetConfig
+ beegoTplFuncMap["map_get"] = MapGet
// go1.2 added template funcs
// Comparisons
@@ -64,7 +62,7 @@ func init() {
beegoTplFuncMap["lt"] = lt // <
beegoTplFuncMap["ne"] = ne // !=
- beegoTplFuncMap["urlfor"] = UrlFor // !=
+ beegoTplFuncMap["urlfor"] = URLFor // !=
}
// AddFuncMap let user to register a func in the template.
@@ -78,7 +76,7 @@ type templatefile struct {
files map[string][]string
}
-func (self *templatefile) visit(paths string, f os.FileInfo, err error) error {
+func (tf *templatefile) visit(paths string, f os.FileInfo, err error) error {
if f == nil {
return err
}
@@ -91,21 +89,21 @@ func (self *templatefile) visit(paths string, f os.FileInfo, err error) error {
replace := strings.NewReplacer("\\", "/")
a := []byte(paths)
- a = a[len([]byte(self.root)):]
+ a = a[len([]byte(tf.root)):]
file := strings.TrimLeft(replace.Replace(string(a)), "/")
subdir := filepath.Dir(file)
- if _, ok := self.files[subdir]; ok {
- self.files[subdir] = append(self.files[subdir], file)
+ if _, ok := tf.files[subdir]; ok {
+ tf.files[subdir] = append(tf.files[subdir], file)
} else {
m := make([]string, 1)
m[0] = file
- self.files[subdir] = m
+ tf.files[subdir] = m
}
return nil
}
-// return this path contains supported template extension of beego or not.
+// HasTemplateExt return this path contains supported template extension of beego or not.
func HasTemplateExt(paths string) bool {
for _, v := range BeeTemplateExt {
if strings.HasSuffix(paths, "."+v) {
@@ -115,7 +113,7 @@ func HasTemplateExt(paths string) bool {
return false
}
-// add new extension for template.
+// AddTemplateExt add new extension for template.
func AddTemplateExt(ext string) {
for _, v := range BeeTemplateExt {
if v == ext {
@@ -125,15 +123,14 @@ func AddTemplateExt(ext string) {
BeeTemplateExt = append(BeeTemplateExt, ext)
}
-// build all template files in a directory.
+// BuildTemplate will build all template files in a directory.
// it makes beego can render any template file in view directory.
-func BuildTemplate(dir string) error {
+func BuildTemplate(dir string, files ...string) error {
if _, err := os.Stat(dir); err != nil {
if os.IsNotExist(err) {
return nil
- } else {
- return errors.New("dir open err")
}
+ return errors.New("dir open err")
}
self := &templatefile{
root: dir,
@@ -148,11 +145,13 @@ func BuildTemplate(dir string) error {
}
for _, v := range self.files {
for _, file := range v {
- t, err := getTemplate(self.root, file, v...)
- if err != nil {
- Trace("parse template err:", file, err)
- } else {
- BeeTemplates[file] = t
+ if len(files) == 0 || utils.InSlice(file, files) {
+ t, err := getTemplate(self.root, file, v...)
+ if err != nil {
+ Trace("parse template err:", file, err)
+ } else {
+ BeeTemplates[file] = t
+ }
}
}
}
@@ -177,7 +176,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
if err != nil {
return nil, [][]string{}, err
}
- reg := regexp.MustCompile(TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"")
+ reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"")
allsub := reg.FindAllStringSubmatch(string(data), -1)
for _, m := range allsub {
if len(m) == 2 {
@@ -198,7 +197,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
}
func getTemplate(root, file string, others ...string) (t *template.Template, err error) {
- t = template.New(file).Delims(TemplateLeft, TemplateRight).Funcs(beegoTplFuncMap)
+ t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap)
var submods [][]string
t, submods, err = getTplDeep(root, file, "", t)
if err != nil {
@@ -240,7 +239,7 @@ func _getTemplate(t0 *template.Template, root string, submods [][]string, others
if err != nil {
continue
}
- reg := regexp.MustCompile(TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"")
+ reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"")
allsub := reg.FindAllStringSubmatch(string(data), -1)
for _, sub := range allsub {
if len(sub) == 2 && sub[1] == m[1] {
@@ -260,3 +259,30 @@ func _getTemplate(t0 *template.Template, root string, submods [][]string, others
}
return
}
+
+// SetViewsPath sets view directory path in beego application.
+func SetViewsPath(path string) *App {
+ BConfig.WebConfig.ViewsPath = path
+ return BeeApp
+}
+
+// SetStaticPath sets static directory path and proper url pattern in beego application.
+// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public".
+func SetStaticPath(url string, path string) *App {
+ if !strings.HasPrefix(url, "/") {
+ url = "/" + url
+ }
+ url = strings.TrimRight(url, "/")
+ BConfig.WebConfig.StaticDir[url] = path
+ return BeeApp
+}
+
+// DelStaticPath removes the static folder setting in this url pattern in beego application.
+func DelStaticPath(url string) *App {
+ if !strings.HasPrefix(url, "/") {
+ url = "/" + url
+ }
+ url = strings.TrimRight(url, "/")
+ delete(BConfig.WebConfig.StaticDir, url)
+ return BeeApp
+}
diff --git a/template_test.go b/template_test.go
index b35da5ce..2e222efc 100644
--- a/template_test.go
+++ b/template_test.go
@@ -20,11 +20,11 @@ import (
"testing"
)
-var header string = `{{define "header"}}
+var header = `{{define "header"}}
Hello, astaxie!
{{end}}`
-var index string = `
+var index = `
beego welcome template
@@ -37,7 +37,7 @@ var index string = `
`
-var block string = `{{define "block"}}
+var block = `{{define "block"}}