1
0
mirror of https://github.com/astaxie/beego.git synced 2024-06-30 11:14:13 +00:00

beego 1.6.0 released

This commit is contained in:
astaxie 2016-01-17 22:55:09 +08:00
commit 895748d632
153 changed files with 6227 additions and 5255 deletions

1
.gitignore vendored
View File

@ -2,3 +2,4 @@
.DS_Store .DS_Store
*.swp *.swp
*.swo *.swo
beego.iml

30
.travis.yml Normal file
View File

@ -0,0 +1,30 @@
language: go
go:
- 1.5.1
services:
- redis-server
- mysql
- postgresql
- memcached
env:
- ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
- ORM_DRIVER=mysql ORM_SOURCE="root:@/orm_test?charset=utf8"
- ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
install:
- go get github.com/lib/pq
- go get github.com/go-sql-driver/mysql
- go get github.com/mattn/go-sqlite3
- go get github.com/bradfitz/gomemcache/memcache
- go get github.com/garyburd/redigo/redis
- go get github.com/beego/x2j
- go get github.com/beego/goyaml2
- go get github.com/belogik/goes
- go get github.com/couchbase/go-couchbase
- go get github.com/siddontang/ledisdb/config
- go get github.com/siddontang/ledisdb/ledis
before_script:
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
- sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi"
- sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi"

52
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,52 @@
# Contributing to beego
beego is an open source project.
It is the work of hundreds of contributors. We appreciate your help!
Here are instructions to get you started. They are probably not perfect,
please let us know if anything feels wrong or incomplete.
## Contribution guidelines
### Pull requests
First of all. beego follow the gitflow. So please send you pull request
to **develop** branch. We will close the pull request to master branch.
We are always happy to receive pull requests, and do our best to
review them as fast as possible. Not sure if that typo is worth a pull
request? Do it! We will appreciate it.
If your pull request is not accepted on the first try, don't be
discouraged! Sometimes we can make a mistake, please do more explaining
for us. We will appreciate it.
We're trying very hard to keep beego simple and fast. We don't want it
to do everything for everybody. This means that we might decide against
incorporating a new feature. But we will give you some advice on how to
do it in other way.
### Create issues
Any significant improvement should be documented as [a GitHub
issue](https://github.com/astaxie/beego/issues) before anybody
starts working on it.
Also when filing an issue, make sure to answer these five questions:
- What version of beego are you using (bee version)?
- What operating system and processor architecture are you using?
- What did you do?
- What did you expect to see?
- What did you see instead?
### but check existing issues and docs first!
Please take a moment to check that an issue doesn't already exist
documenting your bug report or improvement proposal. If it does, it
never hurts to add a quick "+1" or "I have this problem too". This will
help prioritize the most common problems and requests.
Also if you don't know how to use it. please make sure you have read though
the docs in http://beego.me/docs

View File

@ -1,16 +1,38 @@
## Beego ## Beego
[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest) [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego)
[![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego)
beego is an open-source, high-performance, modular, full-stack web framework. beego is used for rapid development of RESTful APIs, web apps and backend services in Go.
It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding.
More info [beego.me](http://beego.me) More info [beego.me](http://beego.me)
## Installation ##Quick Start
######Download and install
go get github.com/astaxie/beego go get github.com/astaxie/beego
######Create file `hello.go`
```go
package main
import "github.com/astaxie/beego"
func main(){
beego.Run()
}
```
######Build and run
```bash
go build hello.go
./hello
```
######Congratulations!
You just built your first beego app.
Open your browser and visit `http://localhost:8000`.
Please see [Documentation](http://beego.me/docs) for more.
## Features ## Features
* RESTful support * RESTful support
@ -26,6 +48,7 @@ More info [beego.me](http://beego.me)
* [English](http://beego.me/docs/intro/) * [English](http://beego.me/docs/intro/)
* [中文文档](http://beego.me/docs/intro/) * [中文文档](http://beego.me/docs/intro/)
* [Русский](http://beego.me/docs/intro/)
## Community ## Community
@ -33,5 +56,5 @@ More info [beego.me](http://beego.me)
## LICENSE ## LICENSE
beego is licensed under the Apache Licence, Version 2.0 beego source code is licensed under the Apache Licence, Version 2.0
(http://www.apache.org/licenses/LICENSE-2.0.html). (http://www.apache.org/licenses/LICENSE-2.0.html).

429
admin.go
View File

@ -65,24 +65,15 @@ func init() {
// AdminIndex is the default http.Handler for admin module. // AdminIndex is the default http.Handler for admin module.
// it matches url pattern "/". // it matches url pattern "/".
func adminIndex(rw http.ResponseWriter, r *http.Request) { func adminIndex(rw http.ResponseWriter, r *http.Request) {
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl)
tmpl = template.Must(tmpl.Parse(indexTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
data := make(map[interface{}]interface{})
tmpl.Execute(rw, data)
} }
// QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter. // QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter.
// it's registered with url pattern "/qbs" in admin module. // it's registered with url pattern "/qbs" in admin module.
func qpsIndex(rw http.ResponseWriter, r *http.Request) { func qpsIndex(rw http.ResponseWriter, r *http.Request) {
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(qpsTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
data := make(map[interface{}]interface{}) data := make(map[interface{}]interface{})
data["Content"] = toolbox.StatisticsMap.GetMap() data["Content"] = toolbox.StatisticsMap.GetMap()
execTpl(rw, data, qpsTpl, defaultScriptsTpl)
tmpl.Execute(rw, data)
} }
// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. // ListConf is the http.Handler of displaying all beego configuration values as key/value pair.
@ -90,178 +81,145 @@ func qpsIndex(rw http.ResponseWriter, r *http.Request) {
func listConf(rw http.ResponseWriter, r *http.Request) { func listConf(rw http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
command := r.Form.Get("command") command := r.Form.Get("command")
if command != "" { if command == "" {
data := make(map[interface{}]interface{}) rw.Write([]byte("command not support"))
switch command { return
case "conf": }
m := make(map[string]interface{})
m["AppName"] = AppName data := make(map[interface{}]interface{})
m["AppPath"] = AppPath switch command {
m["AppConfigPath"] = AppConfigPath case "conf":
m["StaticDir"] = StaticDir m := make(map[string]interface{})
m["StaticExtensionsToGzip"] = StaticExtensionsToGzip m["AppConfigPath"] = AppConfigPath
m["HttpAddr"] = HttpAddr m["AppConfigProvider"] = AppConfigProvider
m["HttpPort"] = HttpPort m["BConfig.AppName"] = BConfig.AppName
m["HttpTLS"] = EnableHttpTLS m["BConfig.RunMode"] = BConfig.RunMode
m["HttpCertFile"] = HttpCertFile m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
m["HttpKeyFile"] = HttpKeyFile m["BConfig.ServerName"] = BConfig.ServerName
m["RecoverPanic"] = RecoverPanic m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
m["AutoRender"] = AutoRender m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
m["ViewsPath"] = ViewsPath m["BConfig.EnableGzip"] = BConfig.EnableGzip
m["RunMode"] = RunMode m["BConfig.MaxMemory"] = BConfig.MaxMemory
m["SessionOn"] = SessionOn m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
m["SessionProvider"] = SessionProvider m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
m["SessionName"] = SessionName m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
m["SessionGCMaxLifetime"] = SessionGCMaxLifetime m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
m["SessionSavePath"] = SessionSavePath m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
m["SessionCookieLifeTime"] = SessionCookieLifeTime m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
m["UseFcgi"] = UseFcgi m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
m["MaxMemory"] = MaxMemory m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
m["EnableGzip"] = EnableGzip m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
m["DirectoryIndex"] = DirectoryIndex m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
m["HttpServerTimeOut"] = HttpServerTimeOut m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
m["ErrorsShow"] = ErrorsShow m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
m["XSRFKEY"] = XSRFKEY m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
m["EnableXSRF"] = EnableXSRF m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
m["XSRFExpire"] = XSRFExpire m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
m["CopyRequestBody"] = CopyRequestBody m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
m["TemplateLeft"] = TemplateLeft m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
m["TemplateRight"] = TemplateRight m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
m["BeegoServerName"] = BeegoServerName m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
m["EnableAdmin"] = EnableAdmin m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
m["AdminHttpAddr"] = AdminHttpAddr m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
m["AdminHttpPort"] = AdminHttpPort m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
m["BConfig.WebConfig.XSRFKEY"] = BConfig.WebConfig.XSRFKey
m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(configTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) data["Content"] = m
tmpl = template.Must(tmpl.Parse(configTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
data["Content"] = m tmpl.Execute(rw, data)
tmpl.Execute(rw, data) case "router":
var (
case "router": content = map[string]interface{}{
content := make(map[string]interface{}) "Fields": []string{
"Router Pattern",
var fields = []string{ "Methods",
fmt.Sprintf("Router Pattern"), "Controller",
fmt.Sprintf("Methods"), },
fmt.Sprintf("Controller"),
} }
content["Fields"] = fields methods = []string{}
methodsData = make(map[string]interface{})
)
for method, t := range BeeApp.Handlers.routers {
methods := []string{} resultList := new([][]string)
methodsData := make(map[string]interface{})
for method, t := range BeeApp.Handlers.routers {
resultList := new([][]string) printTree(resultList, t)
printTree(resultList, t) methods = append(methods, method)
methodsData[method] = resultList
methods = append(methods, method)
methodsData[method] = resultList
}
content["Data"] = methodsData
content["Methods"] = methods
data["Content"] = content
data["Title"] = "Routers"
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(routerAndFilterTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
tmpl.Execute(rw, data)
case "filter":
content := make(map[string]interface{})
var fields = []string{
fmt.Sprintf("Router Pattern"),
fmt.Sprintf("Filter Function"),
}
content["Fields"] = fields
filterTypes := []string{}
filterTypeData := make(map[string]interface{})
if BeeApp.Handlers.enableFilter {
var filterType string
if bf, ok := BeeApp.Handlers.filters[BeforeRouter]; ok {
filterType = "Before Router"
filterTypes = append(filterTypes, filterType)
resultList := new([][]string)
for _, f := range bf {
var result = []string{
fmt.Sprintf("%s", f.pattern),
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
}
*resultList = append(*resultList, result)
}
filterTypeData[filterType] = resultList
}
if bf, ok := BeeApp.Handlers.filters[BeforeExec]; ok {
filterType = "Before Exec"
filterTypes = append(filterTypes, filterType)
resultList := new([][]string)
for _, f := range bf {
var result = []string{
fmt.Sprintf("%s", f.pattern),
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
}
*resultList = append(*resultList, result)
}
filterTypeData[filterType] = resultList
}
if bf, ok := BeeApp.Handlers.filters[AfterExec]; ok {
filterType = "After Exec"
filterTypes = append(filterTypes, filterType)
resultList := new([][]string)
for _, f := range bf {
var result = []string{
fmt.Sprintf("%s", f.pattern),
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
}
*resultList = append(*resultList, result)
}
filterTypeData[filterType] = resultList
}
if bf, ok := BeeApp.Handlers.filters[FinishRouter]; ok {
filterType = "Finish Router"
filterTypes = append(filterTypes, filterType)
resultList := new([][]string)
for _, f := range bf {
var result = []string{
fmt.Sprintf("%s", f.pattern),
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
}
*resultList = append(*resultList, result)
}
filterTypeData[filterType] = resultList
}
}
content["Data"] = filterTypeData
content["Methods"] = filterTypes
data["Content"] = content
data["Title"] = "Filters"
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(routerAndFilterTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
tmpl.Execute(rw, data)
default:
rw.Write([]byte("command not support"))
} }
} else {
content["Data"] = methodsData
content["Methods"] = methods
data["Content"] = content
data["Title"] = "Routers"
execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
case "filter":
var (
content = map[string]interface{}{
"Fields": []string{
"Router Pattern",
"Filter Function",
},
}
filterTypes = []string{}
filterTypeData = make(map[string]interface{})
)
if BeeApp.Handlers.enableFilter {
var filterType string
for k, fr := range map[int]string{
BeforeStatic: "Before Static",
BeforeRouter: "Before Router",
BeforeExec: "Before Exec",
AfterExec: "After Exec",
FinishRouter: "Finish Router"} {
if bf, ok := BeeApp.Handlers.filters[k]; ok {
filterType = fr
filterTypes = append(filterTypes, filterType)
resultList := new([][]string)
for _, f := range bf {
var result = []string{
fmt.Sprintf("%s", f.pattern),
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
}
*resultList = append(*resultList, result)
}
filterTypeData[filterType] = resultList
}
}
}
content["Data"] = filterTypeData
content["Methods"] = filterTypes
data["Content"] = content
data["Title"] = "Filters"
execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
default:
rw.Write([]byte("command not support"))
} }
} }
@ -276,23 +234,23 @@ func printTree(resultList *[][]string, t *Tree) {
if v, ok := l.runObject.(*controllerInfo); ok { if v, ok := l.runObject.(*controllerInfo); ok {
if v.routerType == routerTypeBeego { if v.routerType == routerTypeBeego {
var result = []string{ var result = []string{
fmt.Sprintf("%s", v.pattern), v.pattern,
fmt.Sprintf("%s", v.methods), fmt.Sprintf("%s", v.methods),
fmt.Sprintf("%s", v.controllerType), fmt.Sprintf("%s", v.controllerType),
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} else if v.routerType == routerTypeRESTFul { } else if v.routerType == routerTypeRESTFul {
var result = []string{ var result = []string{
fmt.Sprintf("%s", v.pattern), v.pattern,
fmt.Sprintf("%s", v.methods), fmt.Sprintf("%s", v.methods),
fmt.Sprintf(""), "",
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} else if v.routerType == routerTypeHandler { } else if v.routerType == routerTypeHandler {
var result = []string{ var result = []string{
fmt.Sprintf("%s", v.pattern), v.pattern,
fmt.Sprintf(""), "",
fmt.Sprintf(""), "",
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} }
@ -305,53 +263,49 @@ func printTree(resultList *[][]string, t *Tree) {
func profIndex(rw http.ResponseWriter, r *http.Request) { func profIndex(rw http.ResponseWriter, r *http.Request) {
r.ParseForm() r.ParseForm()
command := r.Form.Get("command") command := r.Form.Get("command")
format := r.Form.Get("format") if command == "" {
data := make(map[string]interface{}) return
}
var result bytes.Buffer var (
if command != "" { format = r.Form.Get("format")
toolbox.ProcessInput(command, &result) data = make(map[interface{}]interface{})
data["Content"] = result.String() result bytes.Buffer
)
toolbox.ProcessInput(command, &result)
data["Content"] = result.String()
if format == "json" && command == "gc summary" { if format == "json" && command == "gc summary" {
dataJson, err := json.Marshal(data) dataJSON, err := json.Marshal(data)
if err != nil { if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError) http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
rw.Header().Set("Content-Type", "application/json")
rw.Write(dataJson)
return return
} }
data["Title"] = command rw.Header().Set("Content-Type", "application/json")
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) rw.Write(dataJSON)
tmpl = template.Must(tmpl.Parse(profillingTpl)) return
if command == "gc summary" {
tmpl = template.Must(tmpl.Parse(gcAjaxTpl))
} else {
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
}
tmpl.Execute(rw, data)
} }
data["Title"] = command
defaultTpl := defaultScriptsTpl
if command == "gc summary" {
defaultTpl = gcAjaxTpl
}
execTpl(rw, data, profillingTpl, defaultTpl)
} }
// Healthcheck is a http.Handler calling health checking and showing the result. // Healthcheck is a http.Handler calling health checking and showing the result.
// it's in "/healthcheck" pattern in admin module. // it's in "/healthcheck" pattern in admin module.
func healthcheck(rw http.ResponseWriter, req *http.Request) { func healthcheck(rw http.ResponseWriter, req *http.Request) {
data := make(map[interface{}]interface{}) var (
data = make(map[interface{}]interface{})
var result = []string{} result = []string{}
fields := []string{ resultList = new([][]string)
fmt.Sprintf("Name"), content = map[string]interface{}{
fmt.Sprintf("Message"), "Fields": []string{"Name", "Message", "Status"},
fmt.Sprintf("Status"), }
} )
resultList := new([][]string)
content := make(map[string]interface{})
for name, h := range toolbox.AdminCheckList { for name, h := range toolbox.AdminCheckList {
if err := h.Check(); err != nil { if err := h.Check(); err != nil {
@ -371,16 +325,10 @@ func healthcheck(rw http.ResponseWriter, req *http.Request) {
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} }
content["Fields"] = fields
content["Data"] = resultList content["Data"] = resultList
data["Content"] = content data["Content"] = content
data["Title"] = "Health Check" data["Title"] = "Health Check"
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) execTpl(rw, data, healthCheckTpl, defaultScriptsTpl)
tmpl = template.Must(tmpl.Parse(healthCheckTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
tmpl.Execute(rw, data)
} }
// TaskStatus is a http.Handler with running task status (task name, status and the last execution). // TaskStatus is a http.Handler with running task status (task name, status and the last execution).
@ -392,10 +340,8 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
req.ParseForm() req.ParseForm()
taskname := req.Form.Get("taskname") taskname := req.Form.Get("taskname")
if taskname != "" { if taskname != "" {
if t, ok := toolbox.AdminTaskList[taskname]; ok { if t, ok := toolbox.AdminTaskList[taskname]; ok {
err := t.Run() if err := t.Run(); err != nil {
if err != nil {
data["Message"] = []string{"error", fmt.Sprintf("%s", err)} data["Message"] = []string{"error", fmt.Sprintf("%s", err)}
} }
data["Message"] = []string{"success", fmt.Sprintf("%s run success,Now the Status is <br>%s", taskname, t.GetStatus())} data["Message"] = []string{"success", fmt.Sprintf("%s run success,Now the Status is <br>%s", taskname, t.GetStatus())}
@ -409,18 +355,18 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
resultList := new([][]string) resultList := new([][]string)
var result = []string{} var result = []string{}
var fields = []string{ var fields = []string{
fmt.Sprintf("Task Name"), "Task Name",
fmt.Sprintf("Task Spec"), "Task Spec",
fmt.Sprintf("Task Status"), "Task Status",
fmt.Sprintf("Last Time"), "Last Time",
fmt.Sprintf(""), "",
} }
for tname, tk := range toolbox.AdminTaskList { for tname, tk := range toolbox.AdminTaskList {
result = []string{ result = []string{
fmt.Sprintf("%s", tname), tname,
fmt.Sprintf("%s", tk.GetSpec()), fmt.Sprintf("%s", tk.GetSpec()),
fmt.Sprintf("%s", tk.GetStatus()), fmt.Sprintf("%s", tk.GetStatus()),
fmt.Sprintf("%s", tk.GetPrev().String()), tk.GetPrev().String(),
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} }
@ -429,9 +375,14 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
content["Data"] = resultList content["Data"] = resultList
data["Content"] = content data["Content"] = content
data["Title"] = "Tasks" data["Title"] = "Tasks"
execTpl(rw, data, tasksTpl, defaultScriptsTpl)
}
func execTpl(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) {
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(tasksTpl)) for _, tpl := range tpls {
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) tmpl = template.Must(tmpl.Parse(tpl))
}
tmpl.Execute(rw, data) tmpl.Execute(rw, data)
} }
@ -451,10 +402,10 @@ func (admin *adminApp) Run() {
if len(toolbox.AdminTaskList) > 0 { if len(toolbox.AdminTaskList) > 0 {
toolbox.StartTask() toolbox.StartTask()
} }
addr := AdminHttpAddr addr := BConfig.Listen.AdminAddr
if AdminHttpPort != 0 { if BConfig.Listen.AdminPort != 0 {
addr = fmt.Sprintf("%s:%d", AdminHttpAddr, AdminHttpPort) addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort)
} }
for p, f := range admin.routers { for p, f := range admin.routers {
http.Handle(p, f) http.Handle(p, f)
@ -462,7 +413,7 @@ func (admin *adminApp) Run() {
BeeLogger.Info("Admin server Running on %s", addr) BeeLogger.Info("Admin server Running on %s", addr)
var err error var err error
if Graceful { if BConfig.Listen.Graceful {
err = grace.ListenAndServe(addr, nil) err = grace.ListenAndServe(addr, nil)
} else { } else {
err = http.ListenAndServe(addr, nil) err = http.ListenAndServe(addr, nil)

418
app.go
View File

@ -20,15 +20,26 @@ import (
"net/http" "net/http"
"net/http/fcgi" "net/http/fcgi"
"os" "os"
"path"
"time" "time"
"github.com/astaxie/beego/grace" "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
var (
// BeeApp is an application instance
BeeApp *App
)
func init() {
// create beego application
BeeApp = NewApp()
}
// App defines beego application with a new PatternServeMux. // App defines beego application with a new PatternServeMux.
type App struct { type App struct {
Handlers *ControllerRegistor Handlers *ControllerRegister
Server *http.Server Server *http.Server
} }
@ -41,132 +52,311 @@ func NewApp() *App {
// Run beego application. // Run beego application.
func (app *App) Run() { func (app *App) Run() {
addr := HttpAddr addr := BConfig.Listen.HTTPAddr
if HttpPort != 0 { if BConfig.Listen.HTTPPort != 0 {
addr = fmt.Sprintf("%s:%d", HttpAddr, HttpPort) addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort)
} }
var ( var (
err error err error
l net.Listener l net.Listener
endRunning = make(chan bool, 1)
) )
endRunning := make(chan bool, 1)
if UseFcgi { // run cgi server
if UseStdIo { if BConfig.Listen.EnableFcgi {
err = fcgi.Serve(nil, app.Handlers) // standard I/O if BConfig.Listen.EnableStdIo {
if err == nil { if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
BeeLogger.Info("Use FCGI via standard I/O") BeeLogger.Info("Use FCGI via standard I/O")
} else { } else {
BeeLogger.Info("Cannot use FCGI via standard I/O", err) BeeLogger.Critical("Cannot use FCGI via standard I/O", err)
} }
return
}
if BConfig.Listen.HTTPPort == 0 {
// remove the Socket file before start
if utils.FileExists(addr) {
os.Remove(addr)
}
l, err = net.Listen("unix", addr)
} else { } else {
if HttpPort == 0 { l, err = net.Listen("tcp", addr)
// remove the Socket file before start }
if utils.FileExists(addr) { if err != nil {
os.Remove(addr) BeeLogger.Critical("Listen: ", err)
}
if err = fcgi.Serve(l, app.Handlers); err != nil {
BeeLogger.Critical("fcgi.Serve: ", err)
}
return
}
app.Server.Handler = app.Handlers
app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
// run graceful mode
if BConfig.Listen.Graceful {
httpsAddr := BConfig.Listen.HTTPSAddr
app.Server.Addr = httpsAddr
if BConfig.Listen.EnableHTTPS {
go func() {
time.Sleep(20 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 {
httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
app.Server.Addr = httpsAddr
}
server := grace.NewServer(httpsAddr, app.Handlers)
server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
if BConfig.Listen.EnableHTTP {
go func() {
server := grace.NewServer(addr, app.Handlers)
server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout
if BConfig.Listen.ListenTCP4 {
server.Network = "tcp4"
}
if err := server.ListenAndServe(); err != nil {
BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
<-endRunning
return
}
// run normal mode
app.Server.Addr = addr
if BConfig.Listen.EnableHTTPS {
go func() {
time.Sleep(20 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
}
BeeLogger.Info("https server Running on %s", app.Server.Addr)
if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
if BConfig.Listen.EnableHTTP {
go func() {
app.Server.Addr = addr
BeeLogger.Info("http server Running on %s", app.Server.Addr)
if BConfig.Listen.ListenTCP4 {
ln, err := net.Listen("tcp4", app.Server.Addr)
if err != nil {
BeeLogger.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return
}
if err = app.Server.Serve(ln); err != nil {
BeeLogger.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return
} }
l, err = net.Listen("unix", addr)
} else { } else {
l, err = net.Listen("tcp", addr) if err := app.Server.ListenAndServe(); err != nil {
BeeLogger.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
} }
if err != nil { }()
BeeLogger.Critical("Listen: ", err)
}
err = fcgi.Serve(l, app.Handlers)
}
} else {
if Graceful {
app.Server.Addr = addr
app.Server.Handler = app.Handlers
app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second
if EnableHttpTLS {
go func() {
time.Sleep(20 * time.Microsecond)
if HttpsPort != 0 {
addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort)
app.Server.Addr = addr
}
server := grace.NewServer(addr, app.Handlers)
server.Server = app.Server
err := server.ListenAndServeTLS(HttpCertFile, HttpKeyFile)
if err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
if EnableHttpListen {
go func() {
server := grace.NewServer(addr, app.Handlers)
server.Server = app.Server
if ListenTCP4 && HttpAddr == "" {
server.Network = "tcp4"
}
err := server.ListenAndServe()
if err != nil {
BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
} else {
app.Server.Addr = addr
app.Server.Handler = app.Handlers
app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second
if EnableHttpTLS {
go func() {
time.Sleep(20 * time.Microsecond)
if HttpsPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort)
}
BeeLogger.Info("https server Running on %s", app.Server.Addr)
err := app.Server.ListenAndServeTLS(HttpCertFile, HttpKeyFile)
if err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
if EnableHttpListen {
go func() {
app.Server.Addr = addr
BeeLogger.Info("http server Running on %s", app.Server.Addr)
if ListenTCP4 && HttpAddr == "" {
ln, err := net.Listen("tcp4", app.Server.Addr)
if err != nil {
BeeLogger.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return
}
err = app.Server.Serve(ln)
if err != nil {
BeeLogger.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return
}
} else {
err := app.Server.ListenAndServe()
if err != nil {
BeeLogger.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}
}()
}
}
} }
<-endRunning <-endRunning
} }
// Router adds a patterned controller handler to BeeApp.
// it's an alias method of App.Router.
// usage:
// simple router
// beego.Router("/admin", &admin.UserController{})
// beego.Router("/admin/index", &admin.ArticleController{})
//
// regex router
//
// beego.Router("/api/:id([0-9]+)", &controllers.RController{})
//
// custom rules
// beego.Router("/api/list",&RestController{},"*:ListFood")
// beego.Router("/api/create",&RestController{},"post:CreateFood")
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
BeeApp.Handlers.Add(rootpath, c, mappingMethods...)
return BeeApp
}
// Include will generate router file in the router/xxx.go from the controller's comments
// usage:
// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
// type BankAccount struct{
// beego.Controller
// }
//
// register the function
// func (b *BankAccount)Mapping(){
// b.Mapping("ShowAccount" , b.ShowAccount)
// b.Mapping("ModifyAccount", b.ModifyAccount)
//}
//
// //@router /account/:id [get]
// func (b *BankAccount) ShowAccount(){
// //logic
// }
//
//
// //@router /account/:id [post]
// func (b *BankAccount) ModifyAccount(){
// //logic
// }
//
// the comments @router url methodlist
// url support all the function Router's pattern
// methodlist [get post head put delete options *]
func Include(cList ...ControllerInterface) *App {
BeeApp.Handlers.Include(cList...)
return BeeApp
}
// RESTRouter adds a restful controller handler to BeeApp.
// its' controller implements beego.ControllerInterface and
// defines a param "pattern/:objectId" to visit each resource.
func RESTRouter(rootpath string, c ControllerInterface) *App {
Router(rootpath, c)
Router(path.Join(rootpath, ":objectId"), c)
return BeeApp
}
// AutoRouter adds defined controller handler to BeeApp.
// it's same to App.AutoRouter.
// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
// visit the url /main/list to exec List function or /main/page to exec Page function.
func AutoRouter(c ControllerInterface) *App {
BeeApp.Handlers.AddAuto(c)
return BeeApp
}
// AutoPrefix adds controller handler to BeeApp with prefix.
// it's same to App.AutoRouterWithPrefix.
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
func AutoPrefix(prefix string, c ControllerInterface) *App {
BeeApp.Handlers.AddAutoPrefix(prefix, c)
return BeeApp
}
// Get used to register router for Get method
// usage:
// beego.Get("/", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Get(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Get(rootpath, f)
return BeeApp
}
// Post used to register router for Post method
// usage:
// beego.Post("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Post(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Post(rootpath, f)
return BeeApp
}
// Delete used to register router for Delete method
// usage:
// beego.Delete("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Delete(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Delete(rootpath, f)
return BeeApp
}
// Put used to register router for Put method
// usage:
// beego.Put("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Put(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Put(rootpath, f)
return BeeApp
}
// Head used to register router for Head method
// usage:
// beego.Head("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Head(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Head(rootpath, f)
return BeeApp
}
// Options used to register router for Options method
// usage:
// beego.Options("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Options(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Options(rootpath, f)
return BeeApp
}
// Patch used to register router for Patch method
// usage:
// beego.Patch("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Patch(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Patch(rootpath, f)
return BeeApp
}
// Any used to register router for all methods
// usage:
// beego.Any("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Any(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Any(rootpath, f)
return BeeApp
}
// Handler used to register a Handler router
// usage:
// beego.Handler("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
BeeApp.Handlers.Handler(rootpath, h, options...)
return BeeApp
}
// InsertFilter adds a FilterFunc with pattern condition and action constant.
// The pos means action constant including
// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App {
BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
return BeeApp
}

353
beego.go
View File

@ -12,243 +12,35 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// beego is an open-source, high-performance, modularity, full-stack web framework
//
// package main
//
// import "github.com/astaxie/beego"
//
// func main() {
// beego.Run()
// }
//
// more infomation: http://beego.me
package beego package beego
import ( import (
"net/http" "fmt"
"os" "os"
"path"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"github.com/astaxie/beego/session"
) )
// beego web framework version. const (
const VERSION = "1.5.0" // VERSION represent beego web framework version.
VERSION = "1.6.0"
type hookfunc func() error //hook function to run // DEV is for develop
var hooks []hookfunc //hook function slice to store the hookfunc DEV = "dev"
// PROD is for production
PROD = "prod"
)
// Router adds a patterned controller handler to BeeApp. //hook function to run
// it's an alias method of App.Router. type hookfunc func() error
// usage:
// simple router
// beego.Router("/admin", &admin.UserController{})
// beego.Router("/admin/index", &admin.ArticleController{})
//
// regex router
//
// beego.Router("/api/:id([0-9]+)", &controllers.RController{})
//
// custom rules
// beego.Router("/api/list",&RestController{},"*:ListFood")
// beego.Router("/api/create",&RestController{},"post:CreateFood")
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
BeeApp.Handlers.Add(rootpath, c, mappingMethods...)
return BeeApp
}
// Router add list from var (
// usage: hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc
// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) )
// type BankAccount struct{
// beego.Controller
// }
//
// register the function
// func (b *BankAccount)Mapping(){
// b.Mapping("ShowAccount" , b.ShowAccount)
// b.Mapping("ModifyAccount", b.ModifyAccount)
//}
//
// //@router /account/:id [get]
// func (b *BankAccount) ShowAccount(){
// //logic
// }
//
//
// //@router /account/:id [post]
// func (b *BankAccount) ModifyAccount(){
// //logic
// }
//
// the comments @router url methodlist
// url support all the function Router's pattern
// methodlist [get post head put delete options *]
func Include(cList ...ControllerInterface) *App {
BeeApp.Handlers.Include(cList...)
return BeeApp
}
// RESTRouter adds a restful controller handler to BeeApp. // AddAPPStartHook is used to register the hookfunc
// its' controller implements beego.ControllerInterface and // The hookfuncs will run in beego.Run()
// defines a param "pattern/:objectId" to visit each resource.
func RESTRouter(rootpath string, c ControllerInterface) *App {
Router(rootpath, c)
Router(path.Join(rootpath, ":objectId"), c)
return BeeApp
}
// AutoRouter adds defined controller handler to BeeApp.
// it's same to App.AutoRouter.
// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
// visit the url /main/list to exec List function or /main/page to exec Page function.
func AutoRouter(c ControllerInterface) *App {
BeeApp.Handlers.AddAuto(c)
return BeeApp
}
// AutoPrefix adds controller handler to BeeApp with prefix.
// it's same to App.AutoRouterWithPrefix.
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
func AutoPrefix(prefix string, c ControllerInterface) *App {
BeeApp.Handlers.AddAutoPrefix(prefix, c)
return BeeApp
}
// register router for Get method
// usage:
// beego.Get("/", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Get(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Get(rootpath, f)
return BeeApp
}
// register router for Post method
// usage:
// beego.Post("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Post(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Post(rootpath, f)
return BeeApp
}
// register router for Delete method
// usage:
// beego.Delete("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Delete(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Delete(rootpath, f)
return BeeApp
}
// register router for Put method
// usage:
// beego.Put("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Put(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Put(rootpath, f)
return BeeApp
}
// register router for Head method
// usage:
// beego.Head("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Head(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Head(rootpath, f)
return BeeApp
}
// register router for Options method
// usage:
// beego.Options("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Options(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Options(rootpath, f)
return BeeApp
}
// register router for Patch method
// usage:
// beego.Patch("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Patch(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Patch(rootpath, f)
return BeeApp
}
// register router for all method
// usage:
// beego.Any("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Any(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Any(rootpath, f)
return BeeApp
}
// register router for own Handler
// usage:
// beego.Handler("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
BeeApp.Handlers.Handler(rootpath, h, options...)
return BeeApp
}
// SetViewsPath sets view directory path in beego application.
func SetViewsPath(path string) *App {
ViewsPath = path
return BeeApp
}
// SetStaticPath sets static directory path and proper url pattern in beego application.
// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public".
func SetStaticPath(url string, path string) *App {
if !strings.HasPrefix(url, "/") {
url = "/" + url
}
url = strings.TrimRight(url, "/")
StaticDir[url] = path
return BeeApp
}
// DelStaticPath removes the static folder setting in this url pattern in beego application.
func DelStaticPath(url string) *App {
if !strings.HasPrefix(url, "/") {
url = "/" + url
}
url = strings.TrimRight(url, "/")
delete(StaticDir, url)
return BeeApp
}
// InsertFilter adds a FilterFunc with pattern condition and action constant.
// The pos means action constant including
// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App {
BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
return BeeApp
}
// The hookfunc will run in beego.Run()
// such as sessionInit, middlerware start, buildtemplate, admin start // such as sessionInit, middlerware start, buildtemplate, admin start
func AddAPPStartHook(hf hookfunc) { func AddAPPStartHook(hf hookfunc) {
hooks = append(hooks, hf) hooks = append(hooks, hf)
@ -256,97 +48,60 @@ func AddAPPStartHook(hf hookfunc) {
// Run beego application. // Run beego application.
// beego.Run() default run on HttpPort // beego.Run() default run on HttpPort
// beego.Run("localhost")
// beego.Run(":8089") // beego.Run(":8089")
// beego.Run("127.0.0.1:8089") // beego.Run("127.0.0.1:8089")
func Run(params ...string) { func Run(params ...string) {
initBeforeHTTPRun()
if len(params) > 0 && params[0] != "" { if len(params) > 0 && params[0] != "" {
strs := strings.Split(params[0], ":") strs := strings.Split(params[0], ":")
if len(strs) > 0 && strs[0] != "" { if len(strs) > 0 && strs[0] != "" {
HttpAddr = strs[0] BConfig.Listen.HTTPAddr = strs[0]
} }
if len(strs) > 1 && strs[1] != "" { if len(strs) > 1 && strs[1] != "" {
HttpPort, _ = strconv.Atoi(strs[1]) BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1])
} }
} }
initBeforeHttpRun()
if EnableAdmin {
go beeAdminApp.Run()
}
BeeApp.Run() BeeApp.Run()
} }
func initBeforeHttpRun() { func initBeforeHTTPRun() {
// if AppConfigPath not In the conf/app.conf reParse config // if AppConfigPath is setted or conf/app.conf exist
if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") {
err := ParseConfig()
if err != nil && AppConfigPath != filepath.Join(workPath, "conf", "app.conf") {
// configuration is critical to app, panic here if parse failed
panic(err)
}
}
//init mime
AddAPPStartHook(initMime)
// do hooks function
for _, hk := range hooks {
err := hk()
if err != nil {
panic(err)
}
}
if SessionOn {
var err error
sessionConfig := AppConfig.String("sessionConfig")
if sessionConfig == "" {
sessionConfig = `{"cookieName":"` + SessionName + `",` +
`"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
`"providerConfig":"` + filepath.ToSlash(SessionSavePath) + `",` +
`"secure":` + strconv.FormatBool(EnableHttpTLS) + `,` +
`"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
`"domain":"` + SessionDomain + `",` +
`"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
}
GlobalSessions, err = session.NewManager(SessionProvider,
sessionConfig)
if err != nil {
panic(err)
}
go GlobalSessions.GC()
}
err := BuildTemplate(ViewsPath)
if err != nil {
if RunMode == "dev" {
Warn(err)
}
}
registerDefaultErrorHandler()
if EnableDocs {
Get("/docs", serverDocs)
Get("/docs/*", serverDocs)
}
}
// this function is for test package init
func TestBeegoInit(apppath string) {
AppPath = apppath
os.Setenv("BEEGO_RUNMODE", "test")
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
err := ParseConfig() err := ParseConfig()
if err != nil && !os.IsNotExist(err) { if err != nil {
// for init if doesn't have app.conf will not panic panic(err)
Info(err) }
//init log
for adaptor, config := range BConfig.Log.Outputs {
err = BeeLogger.SetLogger(adaptor, config)
if err != nil {
fmt.Printf("%s with the config `%s` got err:%s\n", adaptor, config, err)
}
}
SetLogFuncCall(BConfig.Log.FileLineNum)
//init hooks
AddAPPStartHook(registerMime)
AddAPPStartHook(registerDefaultErrorHandler)
AddAPPStartHook(registerSession)
AddAPPStartHook(registerDocs)
AddAPPStartHook(registerTemplate)
AddAPPStartHook(registerAdmin)
for _, hk := range hooks {
if err := hk(); err != nil {
panic(err)
}
} }
os.Chdir(AppPath)
initBeforeHttpRun()
} }
func init() { // TestBeegoInit is for test package init
hooks = make([]hookfunc, 0) func TestBeegoInit(ap string) {
os.Setenv("BEEGO_RUNMODE", "test")
AppConfigPath = filepath.Join(ap, "conf", "app.conf")
os.Chdir(ap)
initBeforeHTTPRun()
} }

6
cache/README.md vendored
View File

@ -26,7 +26,7 @@ Then init a Cache (example with memory adapter)
Use it like this: Use it like this:
bm.Put("astaxie", 1, 10) bm.Put("astaxie", 1, 10 * time.Second)
bm.Get("astaxie") bm.Get("astaxie")
bm.IsExist("astaxie") bm.IsExist("astaxie")
bm.Delete("astaxie") bm.Delete("astaxie")
@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
## Memcache adapter ## Memcache adapter
Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client. Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client.
Configure like this: Configure like this:
@ -52,7 +52,7 @@ Configure like this:
## Redis adapter ## Redis adapter
Redis adapter use the [redigo](http://github.com/garyburd/redigo/redis) client. Redis adapter use the [redigo](http://github.com/garyburd/redigo) client.
Configure like this: Configure like this:

21
cache/cache.go vendored
View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package cache provide a Cache interface and some implemetn engine
// Usage: // Usage:
// //
// import( // import(
@ -22,7 +23,7 @@
// //
// Use it like this: // Use it like this:
// //
// bm.Put("astaxie", 1, 10) // bm.Put("astaxie", 1, 10 * time.Second)
// bm.Get("astaxie") // bm.Get("astaxie")
// bm.IsExist("astaxie") // bm.IsExist("astaxie")
// bm.Delete("astaxie") // bm.Delete("astaxie")
@ -32,13 +33,14 @@ package cache
import ( import (
"fmt" "fmt"
"time"
) )
// Cache interface contains all behaviors for cache adapter. // Cache interface contains all behaviors for cache adapter.
// usage: // usage:
// cache.Register("file",cache.NewFileCache()) // this operation is run in init method of file.go. // cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go.
// c,err := cache.NewCache("file","{....}") // c,err := cache.NewCache("file","{....}")
// c.Put("key",value,3600) // c.Put("key",value, 3600 * time.Second)
// v := c.Get("key") // v := c.Get("key")
// //
// c.Incr("counter") // now is 1 // c.Incr("counter") // now is 1
@ -50,7 +52,7 @@ type Cache interface {
// GetMulti is a batch version of Get. // GetMulti is a batch version of Get.
GetMulti(keys []string) []interface{} GetMulti(keys []string) []interface{}
// set cached value with key and expire time. // set cached value with key and expire time.
Put(key string, val interface{}, timeout int64) error Put(key string, val interface{}, timeout time.Duration) error
// delete cached value by key. // delete cached value by key.
Delete(key string) error Delete(key string) error
// increase cached int value by key, as a counter. // increase cached int value by key, as a counter.
@ -65,12 +67,14 @@ type Cache interface {
StartAndGC(config string) error StartAndGC(config string) error
} }
var adapters = make(map[string]Cache) type CacheInstance func() Cache
var adapters = make(map[string]CacheInstance)
// Register makes a cache adapter available by the adapter name. // Register makes a cache adapter available by the adapter name.
// If Register is called twice with the same name or if driver is nil, // If Register is called twice with the same name or if driver is nil,
// it panics. // it panics.
func Register(name string, adapter Cache) { func Register(name string, adapter CacheInstance) {
if adapter == nil { if adapter == nil {
panic("cache: Register adapter is nil") panic("cache: Register adapter is nil")
} }
@ -80,15 +84,16 @@ func Register(name string, adapter Cache) {
adapters[name] = adapter adapters[name] = adapter
} }
// Create a new cache driver by adapter name and config string. // NewCache Create a new cache driver by adapter name and config string.
// config need to be correct JSON as string: {"interval":360}. // config need to be correct JSON as string: {"interval":360}.
// it will start gc automatically. // it will start gc automatically.
func NewCache(adapterName, config string) (adapter Cache, err error) { func NewCache(adapterName, config string) (adapter Cache, err error) {
adapter, ok := adapters[adapterName] instanceFunc, ok := adapters[adapterName]
if !ok { if !ok {
err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName)
return return
} }
adapter = instanceFunc()
err = adapter.StartAndGC(config) err = adapter.StartAndGC(config)
if err != nil { if err != nil {
adapter = nil adapter = nil

16
cache/cache_test.go vendored
View File

@ -25,7 +25,8 @@ func TestCache(t *testing.T) {
if err != nil { if err != nil {
t.Error("init err") t.Error("init err")
} }
if err = bm.Put("astaxie", 1, 10); err != nil { timeoutDuration := 10 * time.Second
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie") { if !bm.IsExist("astaxie") {
@ -42,7 +43,7 @@ func TestCache(t *testing.T) {
t.Error("check err") t.Error("check err")
} }
if err = bm.Put("astaxie", 1, 10); err != nil { if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
@ -67,7 +68,7 @@ func TestCache(t *testing.T) {
} }
//test GetMulti //test GetMulti
if err = bm.Put("astaxie", "author", 10); err != nil { if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie") { if !bm.IsExist("astaxie") {
@ -77,7 +78,7 @@ func TestCache(t *testing.T) {
t.Error("get err") t.Error("get err")
} }
if err = bm.Put("astaxie1", "author1", 10); err != nil { if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie1") { if !bm.IsExist("astaxie1") {
@ -101,7 +102,8 @@ func TestFileCache(t *testing.T) {
if err != nil { if err != nil {
t.Error("init err") t.Error("init err")
} }
if err = bm.Put("astaxie", 1, 10); err != nil { timeoutDuration := 10 * time.Second
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie") { if !bm.IsExist("astaxie") {
@ -133,7 +135,7 @@ func TestFileCache(t *testing.T) {
} }
//test string //test string
if err = bm.Put("astaxie", "author", 10); err != nil { if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie") { if !bm.IsExist("astaxie") {
@ -144,7 +146,7 @@ func TestFileCache(t *testing.T) {
} }
//test GetMulti //test GetMulti
if err = bm.Put("astaxie1", "author1", 10); err != nil { if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie1") { if !bm.IsExist("astaxie1") {

22
cache/conv.go vendored
View File

@ -19,7 +19,7 @@ import (
"strconv" "strconv"
) )
// convert interface to string. // GetString convert interface to string.
func GetString(v interface{}) string { func GetString(v interface{}) string {
switch result := v.(type) { switch result := v.(type) {
case string: case string:
@ -34,7 +34,7 @@ func GetString(v interface{}) string {
return "" return ""
} }
// convert interface to int. // GetInt convert interface to int.
func GetInt(v interface{}) int { func GetInt(v interface{}) int {
switch result := v.(type) { switch result := v.(type) {
case int: case int:
@ -52,7 +52,7 @@ func GetInt(v interface{}) int {
return 0 return 0
} }
// convert interface to int64. // GetInt64 convert interface to int64.
func GetInt64(v interface{}) int64 { func GetInt64(v interface{}) int64 {
switch result := v.(type) { switch result := v.(type) {
case int: case int:
@ -71,7 +71,7 @@ func GetInt64(v interface{}) int64 {
return 0 return 0
} }
// convert interface to float64. // GetFloat64 convert interface to float64.
func GetFloat64(v interface{}) float64 { func GetFloat64(v interface{}) float64 {
switch result := v.(type) { switch result := v.(type) {
case float64: case float64:
@ -85,7 +85,7 @@ func GetFloat64(v interface{}) float64 {
return 0 return 0
} }
// convert interface to bool. // GetBool convert interface to bool.
func GetBool(v interface{}) bool { func GetBool(v interface{}) bool {
switch result := v.(type) { switch result := v.(type) {
case bool: case bool:
@ -98,15 +98,3 @@ func GetBool(v interface{}) bool {
} }
return false return false
} }
// convert interface to byte slice.
func getByteArray(v interface{}) []byte {
switch result := v.(type) {
case []byte:
return result
case string:
return []byte(result)
default:
return nil
}
}

29
cache/conv_test.go vendored
View File

@ -27,7 +27,7 @@ func TestGetString(t *testing.T) {
if "test2" != GetString(t2) { if "test2" != GetString(t2) {
t.Error("get string from byte array error") t.Error("get string from byte array error")
} }
var t3 int = 1 var t3 = 1
if "1" != GetString(t3) { if "1" != GetString(t3) {
t.Error("get string from int error") t.Error("get string from int error")
} }
@ -35,7 +35,7 @@ func TestGetString(t *testing.T) {
if "1" != GetString(t4) { if "1" != GetString(t4) {
t.Error("get string from int64 error") t.Error("get string from int64 error")
} }
var t5 float64 = 1.1 var t5 = 1.1
if "1.1" != GetString(t5) { if "1.1" != GetString(t5) {
t.Error("get string from float64 error") t.Error("get string from float64 error")
} }
@ -46,7 +46,7 @@ func TestGetString(t *testing.T) {
} }
func TestGetInt(t *testing.T) { func TestGetInt(t *testing.T) {
var t1 int = 1 var t1 = 1
if 1 != GetInt(t1) { if 1 != GetInt(t1) {
t.Error("get int from int error") t.Error("get int from int error")
} }
@ -69,7 +69,7 @@ func TestGetInt(t *testing.T) {
func TestGetInt64(t *testing.T) { func TestGetInt64(t *testing.T) {
var i int64 = 1 var i int64 = 1
var t1 int = 1 var t1 = 1
if i != GetInt64(t1) { if i != GetInt64(t1) {
t.Error("get int64 from int error") t.Error("get int64 from int error")
} }
@ -91,12 +91,12 @@ func TestGetInt64(t *testing.T) {
} }
func TestGetFloat64(t *testing.T) { func TestGetFloat64(t *testing.T) {
var f float64 = 1.11 var f = 1.11
var t1 float32 = 1.11 var t1 float32 = 1.11
if f != GetFloat64(t1) { if f != GetFloat64(t1) {
t.Error("get float64 from float32 error") t.Error("get float64 from float32 error")
} }
var t2 float64 = 1.11 var t2 = 1.11
if f != GetFloat64(t2) { if f != GetFloat64(t2) {
t.Error("get float64 from float64 error") t.Error("get float64 from float64 error")
} }
@ -106,7 +106,7 @@ func TestGetFloat64(t *testing.T) {
} }
var f2 float64 = 1 var f2 float64 = 1
var t4 int = 1 var t4 = 1
if f2 != GetFloat64(t4) { if f2 != GetFloat64(t4) {
t.Error("get float64 from int error") t.Error("get float64 from int error")
} }
@ -130,21 +130,6 @@ func TestGetBool(t *testing.T) {
} }
} }
func TestGetByteArray(t *testing.T) {
var b = []byte("test")
var t1 = []byte("test")
if !byteArrayEquals(b, getByteArray(t1)) {
t.Error("get byte array from byte array error")
}
var t2 = "test"
if !byteArrayEquals(b, getByteArray(t2)) {
t.Error("get byte array from string error")
}
if nil != getByteArray(nil) {
t.Error("get byte array from nil error")
}
}
func byteArrayEquals(a []byte, b []byte) bool { func byteArrayEquals(a []byte, b []byte) bool {
if len(a) != len(b) { if len(a) != len(b) {
return false return false

71
cache/file.go vendored
View File

@ -29,23 +29,20 @@ import (
"time" "time"
) )
func init() {
Register("file", NewFileCache())
}
// FileCacheItem is basic unit of file cache adapter. // FileCacheItem is basic unit of file cache adapter.
// it contains data and expire time. // it contains data and expire time.
type FileCacheItem struct { type FileCacheItem struct {
Data interface{} Data interface{}
Lastaccess int64 Lastaccess time.Time
Expired int64 Expired time.Time
} }
// FileCache Config
var ( var (
FileCachePath string = "cache" // cache directory FileCachePath = "cache" // cache directory
FileCacheFileSuffix string = ".bin" // cache file suffix FileCacheFileSuffix = ".bin" // cache file suffix
FileCacheDirectoryLevel int = 2 // cache file deep level if auto generated cache files. FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files.
FileCacheEmbedExpiry int64 = 0 // cache expire time, default is no expire forever. FileCacheEmbedExpiry time.Duration = 0 // cache expire time, default is no expire forever.
) )
// FileCache is cache adapter for file storage. // FileCache is cache adapter for file storage.
@ -56,14 +53,14 @@ type FileCache struct {
EmbedExpiry int EmbedExpiry int
} }
// Create new file cache with no config. // NewFileCache Create new file cache with no config.
// the level and expiry need set in method StartAndGC as config string. // the level and expiry need set in method StartAndGC as config string.
func NewFileCache() *FileCache { func NewFileCache() Cache {
// return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix}
return &FileCache{} return &FileCache{}
} }
// Start and begin gc for file cache. // StartAndGC will start and begin gc for file cache.
// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0} // the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}
func (fc *FileCache) StartAndGC(config string) error { func (fc *FileCache) StartAndGC(config string) error {
@ -79,7 +76,7 @@ func (fc *FileCache) StartAndGC(config string) error {
cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel)
} }
if _, ok := cfg["EmbedExpiry"]; !ok { if _, ok := cfg["EmbedExpiry"]; !ok {
cfg["EmbedExpiry"] = strconv.FormatInt(FileCacheEmbedExpiry, 10) cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10)
} }
fc.CachePath = cfg["CachePath"] fc.CachePath = cfg["CachePath"]
fc.FileSuffix = cfg["FileSuffix"] fc.FileSuffix = cfg["FileSuffix"]
@ -120,13 +117,13 @@ func (fc *FileCache) getCacheFileName(key string) string {
// Get value from file cache. // Get value from file cache.
// if non-exist or expired, return empty string. // if non-exist or expired, return empty string.
func (fc *FileCache) Get(key string) interface{} { func (fc *FileCache) Get(key string) interface{} {
fileData, err := File_get_contents(fc.getCacheFileName(key)) fileData, err := FileGetContents(fc.getCacheFileName(key))
if err != nil { if err != nil {
return "" return ""
} }
var to FileCacheItem var to FileCacheItem
Gob_decode(fileData, &to) GobDecode(fileData, &to)
if to.Expired < time.Now().Unix() { if to.Expired.Before(time.Now()) {
return "" return ""
} }
return to.Data return to.Data
@ -145,21 +142,21 @@ func (fc *FileCache) GetMulti(keys []string) []interface{} {
// Put value into file cache. // Put value into file cache.
// timeout means how long to keep this file, unit of ms. // timeout means how long to keep this file, unit of ms.
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever. // if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
func (fc *FileCache) Put(key string, val interface{}, timeout int64) error { func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error {
gob.Register(val) gob.Register(val)
item := FileCacheItem{Data: val} item := FileCacheItem{Data: val}
if timeout == FileCacheEmbedExpiry { if timeout == FileCacheEmbedExpiry {
item.Expired = time.Now().Unix() + (86400 * 365 * 10) // ten years item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years
} else { } else {
item.Expired = time.Now().Unix() + timeout item.Expired = time.Now().Add(timeout)
} }
item.Lastaccess = time.Now().Unix() item.Lastaccess = time.Now()
data, err := Gob_encode(item) data, err := GobEncode(item)
if err != nil { if err != nil {
return err return err
} }
return File_put_contents(fc.getCacheFileName(key), data) return FilePutContents(fc.getCacheFileName(key), data)
} }
// Delete file cache value. // Delete file cache value.
@ -171,7 +168,7 @@ func (fc *FileCache) Delete(key string) error {
return nil return nil
} }
// Increase cached int value. // Incr will increase cached int value.
// fc value is saving forever unless Delete. // fc value is saving forever unless Delete.
func (fc *FileCache) Incr(key string) error { func (fc *FileCache) Incr(key string) error {
data := fc.Get(key) data := fc.Get(key)
@ -185,7 +182,7 @@ func (fc *FileCache) Incr(key string) error {
return nil return nil
} }
// Decrease cached int value. // Decr will decrease cached int value.
func (fc *FileCache) Decr(key string) error { func (fc *FileCache) Decr(key string) error {
data := fc.Get(key) data := fc.Get(key)
var decr int var decr int
@ -198,13 +195,13 @@ func (fc *FileCache) Decr(key string) error {
return nil return nil
} }
// Check value is exist. // IsExist check value is exist.
func (fc *FileCache) IsExist(key string) bool { func (fc *FileCache) IsExist(key string) bool {
ret, _ := exists(fc.getCacheFileName(key)) ret, _ := exists(fc.getCacheFileName(key))
return ret return ret
} }
// Clean cached files. // ClearAll will clean cached files.
// not implemented. // not implemented.
func (fc *FileCache) ClearAll() error { func (fc *FileCache) ClearAll() error {
return nil return nil
@ -222,9 +219,9 @@ func exists(path string) (bool, error) {
return false, err return false, err
} }
// Get bytes to file. // FileGetContents Get bytes to file.
// if non-exist, create this file. // if non-exist, create this file.
func File_get_contents(filename string) (data []byte, e error) { func FileGetContents(filename string) (data []byte, e error) {
f, e := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm) f, e := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm)
if e != nil { if e != nil {
return return
@ -242,9 +239,9 @@ func File_get_contents(filename string) (data []byte, e error) {
return return
} }
// Put bytes to file. // FilePutContents Put bytes to file.
// if non-exist, create this file. // if non-exist, create this file.
func File_put_contents(filename string, content []byte) error { func FilePutContents(filename string, content []byte) error {
fp, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm) fp, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm)
if err != nil { if err != nil {
return err return err
@ -254,8 +251,8 @@ func File_put_contents(filename string, content []byte) error {
return err return err
} }
// Gob encodes file cache item. // GobEncode Gob encodes file cache item.
func Gob_encode(data interface{}) ([]byte, error) { func GobEncode(data interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf) enc := gob.NewEncoder(buf)
err := enc.Encode(data) err := enc.Encode(data)
@ -265,9 +262,13 @@ func Gob_encode(data interface{}) ([]byte, error) {
return buf.Bytes(), err return buf.Bytes(), err
} }
// Gob decodes file cache item. // GobDecode Gob decodes file cache item.
func Gob_decode(data []byte, to *FileCacheItem) error { func GobDecode(data []byte, to *FileCacheItem) error {
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf) dec := gob.NewDecoder(buf)
return dec.Decode(&to) return dec.Decode(&to)
} }
func init() {
Register("file", NewFileCache)
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// package memcahe for cache provider // Package memcache for cache provider
// //
// depend on github.com/bradfitz/gomemcache/memcache // depend on github.com/bradfitz/gomemcache/memcache
// //
@ -37,21 +37,22 @@ import (
"github.com/bradfitz/gomemcache/memcache" "github.com/bradfitz/gomemcache/memcache"
"github.com/astaxie/beego/cache" "github.com/astaxie/beego/cache"
"time"
) )
// Memcache adapter. // Cache Memcache adapter.
type MemcacheCache struct { type Cache struct {
conn *memcache.Client conn *memcache.Client
conninfo []string conninfo []string
} }
// create new memcache adapter. // NewMemCache create new memcache adapter.
func NewMemCache() *MemcacheCache { func NewMemCache() cache.Cache {
return &MemcacheCache{} return &Cache{}
} }
// get value from memcache. // Get get value from memcache.
func (rc *MemcacheCache) Get(key string) interface{} { func (rc *Cache) Get(key string) interface{} {
if rc.conn == nil { if rc.conn == nil {
if err := rc.connectInit(); err != nil { if err := rc.connectInit(); err != nil {
return err return err
@ -63,8 +64,8 @@ func (rc *MemcacheCache) Get(key string) interface{} {
return nil return nil
} }
// get value from memcache. // GetMulti get value from memcache.
func (rc *MemcacheCache) GetMulti(keys []string) []interface{} { func (rc *Cache) GetMulti(keys []string) []interface{} {
size := len(keys) size := len(keys)
var rv []interface{} var rv []interface{}
if rc.conn == nil { if rc.conn == nil {
@ -81,16 +82,15 @@ func (rc *MemcacheCache) GetMulti(keys []string) []interface{} {
rv = append(rv, string(v.Value)) rv = append(rv, string(v.Value))
} }
return rv return rv
} else {
for i := 0; i < size; i++ {
rv = append(rv, err)
}
return rv
} }
for i := 0; i < size; i++ {
rv = append(rv, err)
}
return rv
} }
// put value to memcache. only support string. // Put put value to memcache. only support string.
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error { func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
if rc.conn == nil { if rc.conn == nil {
if err := rc.connectInit(); err != nil { if err := rc.connectInit(); err != nil {
return err return err
@ -100,12 +100,12 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
if !ok { if !ok {
return errors.New("val must string") return errors.New("val must string")
} }
item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout)} item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout/time.Second)}
return rc.conn.Set(&item) return rc.conn.Set(&item)
} }
// delete value in memcache. // Delete delete value in memcache.
func (rc *MemcacheCache) Delete(key string) error { func (rc *Cache) Delete(key string) error {
if rc.conn == nil { if rc.conn == nil {
if err := rc.connectInit(); err != nil { if err := rc.connectInit(); err != nil {
return err return err
@ -114,8 +114,8 @@ func (rc *MemcacheCache) Delete(key string) error {
return rc.conn.Delete(key) return rc.conn.Delete(key)
} }
// increase counter. // Incr increase counter.
func (rc *MemcacheCache) Incr(key string) error { func (rc *Cache) Incr(key string) error {
if rc.conn == nil { if rc.conn == nil {
if err := rc.connectInit(); err != nil { if err := rc.connectInit(); err != nil {
return err return err
@ -125,8 +125,8 @@ func (rc *MemcacheCache) Incr(key string) error {
return err return err
} }
// decrease counter. // Decr decrease counter.
func (rc *MemcacheCache) Decr(key string) error { func (rc *Cache) Decr(key string) error {
if rc.conn == nil { if rc.conn == nil {
if err := rc.connectInit(); err != nil { if err := rc.connectInit(); err != nil {
return err return err
@ -136,8 +136,8 @@ func (rc *MemcacheCache) Decr(key string) error {
return err return err
} }
// check value exists in memcache. // IsExist check value exists in memcache.
func (rc *MemcacheCache) IsExist(key string) bool { func (rc *Cache) IsExist(key string) bool {
if rc.conn == nil { if rc.conn == nil {
if err := rc.connectInit(); err != nil { if err := rc.connectInit(); err != nil {
return false return false
@ -150,8 +150,8 @@ func (rc *MemcacheCache) IsExist(key string) bool {
return true return true
} }
// clear all cached in memcache. // ClearAll clear all cached in memcache.
func (rc *MemcacheCache) ClearAll() error { func (rc *Cache) ClearAll() error {
if rc.conn == nil { if rc.conn == nil {
if err := rc.connectInit(); err != nil { if err := rc.connectInit(); err != nil {
return err return err
@ -160,10 +160,10 @@ func (rc *MemcacheCache) ClearAll() error {
return rc.conn.FlushAll() return rc.conn.FlushAll()
} }
// start memcache adapter. // StartAndGC start memcache adapter.
// config string is like {"conn":"connection info"}. // config string is like {"conn":"connection info"}.
// if connecting error, return. // if connecting error, return.
func (rc *MemcacheCache) StartAndGC(config string) error { func (rc *Cache) StartAndGC(config string) error {
var cf map[string]string var cf map[string]string
json.Unmarshal([]byte(config), &cf) json.Unmarshal([]byte(config), &cf)
if _, ok := cf["conn"]; !ok { if _, ok := cf["conn"]; !ok {
@ -179,11 +179,11 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
} }
// connect to memcache and keep the connection. // connect to memcache and keep the connection.
func (rc *MemcacheCache) connectInit() error { func (rc *Cache) connectInit() error {
rc.conn = memcache.New(rc.conninfo...) rc.conn = memcache.New(rc.conninfo...)
return nil return nil
} }
func init() { func init() {
cache.Register("memcache", NewMemCache()) cache.Register("memcache", NewMemCache)
} }

View File

@ -23,12 +23,13 @@ import (
"time" "time"
) )
func TestRedisCache(t *testing.T) { func TestMemcacheCache(t *testing.T) {
bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`) bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`)
if err != nil { if err != nil {
t.Error("init err") t.Error("init err")
} }
if err = bm.Put("astaxie", "1", 10); err != nil { timeoutDuration := 10 * time.Second
if err = bm.Put("astaxie", "1", timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie") { if !bm.IsExist("astaxie") {
@ -40,7 +41,7 @@ func TestRedisCache(t *testing.T) {
if bm.IsExist("astaxie") { if bm.IsExist("astaxie") {
t.Error("check err") t.Error("check err")
} }
if err = bm.Put("astaxie", "1", 10); err != nil { if err = bm.Put("astaxie", "1", timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
@ -69,7 +70,7 @@ func TestRedisCache(t *testing.T) {
} }
//test string //test string
if err = bm.Put("astaxie", "author", 10); err != nil { if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie") { if !bm.IsExist("astaxie") {
@ -81,7 +82,7 @@ func TestRedisCache(t *testing.T) {
} }
//test GetMulti //test GetMulti
if err = bm.Put("astaxie1", "author1", 10); err != nil { if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie1") { if !bm.IsExist("astaxie1") {

112
cache/memory.go vendored
View File

@ -17,34 +17,41 @@ package cache
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"sync" "sync"
"time" "time"
) )
var ( var (
// clock time of recycling the expired cache items in memory. // DefaultEvery means the clock time of recycling the expired cache items in memory.
DefaultEvery int = 60 // 1 minute DefaultEvery = 60 // 1 minute
) )
// Memory cache item. // MemoryItem store memory cache item.
type MemoryItem struct { type MemoryItem struct {
val interface{} val interface{}
Lastaccess time.Time createdTime time.Time
expired int64 lifespan time.Duration
} }
// Memory cache adapter. func (mi *MemoryItem) isExpire() bool {
// 0 means forever
if mi.lifespan == 0 {
return false
}
return time.Now().Sub(mi.createdTime) > mi.lifespan
}
// MemoryCache is Memory cache adapter.
// it contains a RW locker for safe map storage. // it contains a RW locker for safe map storage.
type MemoryCache struct { type MemoryCache struct {
lock sync.RWMutex sync.RWMutex
dur time.Duration dur time.Duration
items map[string]*MemoryItem items map[string]*MemoryItem
Every int // run an expiration check Every clock time Every int // run an expiration check Every clock time
} }
// NewMemoryCache returns a new MemoryCache. // NewMemoryCache returns a new MemoryCache.
func NewMemoryCache() *MemoryCache { func NewMemoryCache() Cache {
cache := MemoryCache{items: make(map[string]*MemoryItem)} cache := MemoryCache{items: make(map[string]*MemoryItem)}
return &cache return &cache
} }
@ -52,11 +59,10 @@ func NewMemoryCache() *MemoryCache {
// Get cache from memory. // Get cache from memory.
// if non-existed or expired, return nil. // if non-existed or expired, return nil.
func (bc *MemoryCache) Get(name string) interface{} { func (bc *MemoryCache) Get(name string) interface{} {
bc.lock.RLock() bc.RLock()
defer bc.lock.RUnlock() defer bc.RUnlock()
if itm, ok := bc.items[name]; ok { if itm, ok := bc.items[name]; ok {
if (time.Now().Unix() - itm.Lastaccess.Unix()) > itm.expired { if itm.isExpire() {
go bc.Delete(name)
return nil return nil
} }
return itm.val return itm.val
@ -75,22 +81,22 @@ func (bc *MemoryCache) GetMulti(names []string) []interface{} {
} }
// Put cache to memory. // Put cache to memory.
// if expired is 0, it will be cleaned by next gc operation ( default gc clock is 1 minute). // if lifespan is 0, it will be forever till restart.
func (bc *MemoryCache) Put(name string, value interface{}, expired int64) error { func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error {
bc.lock.Lock() bc.Lock()
defer bc.lock.Unlock() defer bc.Unlock()
bc.items[name] = &MemoryItem{ bc.items[name] = &MemoryItem{
val: value, val: value,
Lastaccess: time.Now(), createdTime: time.Now(),
expired: expired, lifespan: lifespan,
} }
return nil return nil
} }
/// Delete cache in memory. // Delete cache in memory.
func (bc *MemoryCache) Delete(name string) error { func (bc *MemoryCache) Delete(name string) error {
bc.lock.Lock() bc.Lock()
defer bc.lock.Unlock() defer bc.Unlock()
if _, ok := bc.items[name]; !ok { if _, ok := bc.items[name]; !ok {
return errors.New("key not exist") return errors.New("key not exist")
} }
@ -101,11 +107,11 @@ func (bc *MemoryCache) Delete(name string) error {
return nil return nil
} }
// Increase cache counter in memory. // Incr increase cache counter in memory.
// it supports int,int64,int32,uint,uint64,uint32. // it supports int,int32,int64,uint,uint32,uint64.
func (bc *MemoryCache) Incr(key string) error { func (bc *MemoryCache) Incr(key string) error {
bc.lock.RLock() bc.RLock()
defer bc.lock.RUnlock() defer bc.RUnlock()
itm, ok := bc.items[key] itm, ok := bc.items[key]
if !ok { if !ok {
return errors.New("key not exist") return errors.New("key not exist")
@ -113,10 +119,10 @@ func (bc *MemoryCache) Incr(key string) error {
switch itm.val.(type) { switch itm.val.(type) {
case int: case int:
itm.val = itm.val.(int) + 1 itm.val = itm.val.(int) + 1
case int64:
itm.val = itm.val.(int64) + 1
case int32: case int32:
itm.val = itm.val.(int32) + 1 itm.val = itm.val.(int32) + 1
case int64:
itm.val = itm.val.(int64) + 1
case uint: case uint:
itm.val = itm.val.(uint) + 1 itm.val = itm.val.(uint) + 1
case uint32: case uint32:
@ -124,15 +130,15 @@ func (bc *MemoryCache) Incr(key string) error {
case uint64: case uint64:
itm.val = itm.val.(uint64) + 1 itm.val = itm.val.(uint64) + 1
default: default:
return errors.New("item val is not int int64 int32") return errors.New("item val is not (u)int (u)int32 (u)int64")
} }
return nil return nil
} }
// Decrease counter in memory. // Decr decrease counter in memory.
func (bc *MemoryCache) Decr(key string) error { func (bc *MemoryCache) Decr(key string) error {
bc.lock.RLock() bc.RLock()
defer bc.lock.RUnlock() defer bc.RUnlock()
itm, ok := bc.items[key] itm, ok := bc.items[key]
if !ok { if !ok {
return errors.New("key not exist") return errors.New("key not exist")
@ -168,23 +174,25 @@ func (bc *MemoryCache) Decr(key string) error {
return nil return nil
} }
// check cache exist in memory. // IsExist check cache exist in memory.
func (bc *MemoryCache) IsExist(name string) bool { func (bc *MemoryCache) IsExist(name string) bool {
bc.lock.RLock() bc.RLock()
defer bc.lock.RUnlock() defer bc.RUnlock()
_, ok := bc.items[name] if v, ok := bc.items[name]; ok {
return ok return !v.isExpire()
}
return false
} }
// delete all cache in memory. // ClearAll will delete all cache in memory.
func (bc *MemoryCache) ClearAll() error { func (bc *MemoryCache) ClearAll() error {
bc.lock.Lock() bc.Lock()
defer bc.lock.Unlock() defer bc.Unlock()
bc.items = make(map[string]*MemoryItem) bc.items = make(map[string]*MemoryItem)
return nil return nil
} }
// start memory cache. it will check expiration in every clock time. // StartAndGC start memory cache. it will check expiration in every clock time.
func (bc *MemoryCache) StartAndGC(config string) error { func (bc *MemoryCache) StartAndGC(config string) error {
var cf map[string]int var cf map[string]int
json.Unmarshal([]byte(config), &cf) json.Unmarshal([]byte(config), &cf)
@ -192,10 +200,7 @@ func (bc *MemoryCache) StartAndGC(config string) error {
cf = make(map[string]int) cf = make(map[string]int)
cf["interval"] = DefaultEvery cf["interval"] = DefaultEvery
} }
dur, err := time.ParseDuration(fmt.Sprintf("%ds", cf["interval"])) dur := time.Duration(cf["interval"]) * time.Second
if err != nil {
return err
}
bc.Every = cf["interval"] bc.Every = cf["interval"]
bc.dur = dur bc.dur = dur
go bc.vaccuum() go bc.vaccuum()
@ -213,20 +218,21 @@ func (bc *MemoryCache) vaccuum() {
return return
} }
for name := range bc.items { for name := range bc.items {
bc.item_expired(name) bc.itemExpired(name)
} }
} }
} }
// item_expired returns true if an item is expired. // itemExpired returns true if an item is expired.
func (bc *MemoryCache) item_expired(name string) bool { func (bc *MemoryCache) itemExpired(name string) bool {
bc.lock.Lock() bc.Lock()
defer bc.lock.Unlock() defer bc.Unlock()
itm, ok := bc.items[name] itm, ok := bc.items[name]
if !ok { if !ok {
return true return true
} }
if time.Now().Unix()-itm.Lastaccess.Unix() >= itm.expired { if itm.isExpire() {
delete(bc.items, name) delete(bc.items, name)
return true return true
} }
@ -234,5 +240,5 @@ func (bc *MemoryCache) item_expired(name string) bool {
} }
func init() { func init() {
Register("memory", NewMemoryCache()) Register("memory", NewMemoryCache)
} }

56
cache/redis/redis.go vendored
View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// package redis for cache provider // Package redis for cache provider
// //
// depend on github.com/garyburd/redigo/redis // depend on github.com/garyburd/redigo/redis
// //
@ -41,12 +41,12 @@ import (
) )
var ( var (
// the collection name of redis for cache adapter. // DefaultKey the collection name of redis for cache adapter.
DefaultKey string = "beecacheRedis" DefaultKey = "beecacheRedis"
) )
// Redis cache adapter. // Cache is Redis cache adapter.
type RedisCache struct { type Cache struct {
p *redis.Pool // redis connection pool p *redis.Pool // redis connection pool
conninfo string conninfo string
dbNum int dbNum int
@ -54,13 +54,13 @@ type RedisCache struct {
password string password string
} }
// create new redis cache with default collection name. // NewRedisCache create new redis cache with default collection name.
func NewRedisCache() *RedisCache { func NewRedisCache() cache.Cache {
return &RedisCache{key: DefaultKey} return &Cache{key: DefaultKey}
} }
// actually do the redis cmds // actually do the redis cmds
func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) { func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
c := rc.p.Get() c := rc.p.Get()
defer c.Close() defer c.Close()
@ -68,7 +68,7 @@ func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interfa
} }
// Get cache from redis. // Get cache from redis.
func (rc *RedisCache) Get(key string) interface{} { func (rc *Cache) Get(key string) interface{} {
if v, err := rc.do("GET", key); err == nil { if v, err := rc.do("GET", key); err == nil {
return v return v
} }
@ -76,7 +76,7 @@ func (rc *RedisCache) Get(key string) interface{} {
} }
// GetMulti get cache from redis. // GetMulti get cache from redis.
func (rc *RedisCache) GetMulti(keys []string) []interface{} { func (rc *Cache) GetMulti(keys []string) []interface{} {
size := len(keys) size := len(keys)
var rv []interface{} var rv []interface{}
c := rc.p.Get() c := rc.p.Get()
@ -108,10 +108,10 @@ ERROR:
return rv return rv
} }
// put cache to redis. // Put put cache to redis.
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error { func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
var err error var err error
if _, err = rc.do("SETEX", key, timeout, val); err != nil { if _, err = rc.do("SETEX", key, int64(timeout/time.Second), val); err != nil {
return err return err
} }
@ -121,8 +121,8 @@ func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
return err return err
} }
// delete cache in redis. // Delete delete cache in redis.
func (rc *RedisCache) Delete(key string) error { func (rc *Cache) Delete(key string) error {
var err error var err error
if _, err = rc.do("DEL", key); err != nil { if _, err = rc.do("DEL", key); err != nil {
return err return err
@ -131,8 +131,8 @@ func (rc *RedisCache) Delete(key string) error {
return err return err
} }
// check cache's existence in redis. // IsExist check cache's existence in redis.
func (rc *RedisCache) IsExist(key string) bool { func (rc *Cache) IsExist(key string) bool {
v, err := redis.Bool(rc.do("EXISTS", key)) v, err := redis.Bool(rc.do("EXISTS", key))
if err != nil { if err != nil {
return false return false
@ -145,20 +145,20 @@ func (rc *RedisCache) IsExist(key string) bool {
return v return v
} }
// increase counter in redis. // Incr increase counter in redis.
func (rc *RedisCache) Incr(key string) error { func (rc *Cache) Incr(key string) error {
_, err := redis.Bool(rc.do("INCRBY", key, 1)) _, err := redis.Bool(rc.do("INCRBY", key, 1))
return err return err
} }
// decrease counter in redis. // Decr decrease counter in redis.
func (rc *RedisCache) Decr(key string) error { func (rc *Cache) Decr(key string) error {
_, err := redis.Bool(rc.do("INCRBY", key, -1)) _, err := redis.Bool(rc.do("INCRBY", key, -1))
return err return err
} }
// clean all cache in redis. delete this redis collection. // ClearAll clean all cache in redis. delete this redis collection.
func (rc *RedisCache) ClearAll() error { func (rc *Cache) ClearAll() error {
cachedKeys, err := redis.Strings(rc.do("HKEYS", rc.key)) cachedKeys, err := redis.Strings(rc.do("HKEYS", rc.key))
if err != nil { if err != nil {
return err return err
@ -172,11 +172,11 @@ func (rc *RedisCache) ClearAll() error {
return err return err
} }
// start redis cache adapter. // StartAndGC start redis cache adapter.
// config is like {"key":"collection key","conn":"connection info","dbNum":"0"} // config is like {"key":"collection key","conn":"connection info","dbNum":"0"}
// the cache item in redis are stored forever, // the cache item in redis are stored forever,
// so no gc operation. // so no gc operation.
func (rc *RedisCache) StartAndGC(config string) error { func (rc *Cache) StartAndGC(config string) error {
var cf map[string]string var cf map[string]string
json.Unmarshal([]byte(config), &cf) json.Unmarshal([]byte(config), &cf)
@ -206,7 +206,7 @@ func (rc *RedisCache) StartAndGC(config string) error {
} }
// connect to redis. // connect to redis.
func (rc *RedisCache) connectInit() { func (rc *Cache) connectInit() {
dialFunc := func() (c redis.Conn, err error) { dialFunc := func() (c redis.Conn, err error) {
c, err = redis.Dial("tcp", rc.conninfo) c, err = redis.Dial("tcp", rc.conninfo)
if err != nil { if err != nil {
@ -236,5 +236,5 @@ func (rc *RedisCache) connectInit() {
} }
func init() { func init() {
cache.Register("redis", NewRedisCache()) cache.Register("redis", NewRedisCache)
} }

View File

@ -28,19 +28,20 @@ func TestRedisCache(t *testing.T) {
if err != nil { if err != nil {
t.Error("init err") t.Error("init err")
} }
if err = bm.Put("astaxie", 1, 10); err != nil { timeoutDuration := 10 * time.Second
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie") { if !bm.IsExist("astaxie") {
t.Error("check err") t.Error("check err")
} }
time.Sleep(10 * time.Second) time.Sleep(11 * time.Second)
if bm.IsExist("astaxie") { if bm.IsExist("astaxie") {
t.Error("check err") t.Error("check err")
} }
if err = bm.Put("astaxie", 1, 10); err != nil { if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
@ -69,7 +70,7 @@ func TestRedisCache(t *testing.T) {
} }
//test string //test string
if err = bm.Put("astaxie", "author", 10); err != nil { if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie") { if !bm.IsExist("astaxie") {
@ -81,7 +82,7 @@ func TestRedisCache(t *testing.T) {
} }
//test GetMulti //test GetMulti
if err = bm.Put("astaxie1", "author1", 10); err != nil { if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
if !bm.IsExist("astaxie1") { if !bm.IsExist("astaxie1") {

666
config.go
View File

@ -15,78 +15,264 @@
package beego package beego
import ( import (
"fmt"
"html/template" "html/template"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
var ( type BeegoConfig struct {
BeeApp *App // beego application AppName string //Application name
AppName string RunMode string //Running Mode: dev | prod
AppPath string RouterCaseSensitive bool
workPath string ServerName string
AppConfigPath string RecoverPanic bool
CopyRequestBody bool
EnableGzip bool
MaxMemory int64
EnableErrorsShow bool
Listen Listen
WebConfig WebConfig
Log LogConfig
}
type Listen struct {
Graceful bool // Graceful means use graceful module to start the server
ServerTimeOut int64
ListenTCP4 bool
EnableHTTP bool
HTTPAddr string
HTTPPort int
EnableHTTPS bool
HTTPSAddr string
HTTPSPort int
HTTPSCertFile string
HTTPSKeyFile string
EnableAdmin bool
AdminAddr string
AdminPort int
EnableFcgi bool
EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O
}
type WebConfig struct {
AutoRender bool
EnableDocs bool
FlashName string
FlashSeparator string
DirectoryIndex bool
StaticDir map[string]string StaticDir map[string]string
TemplateCache map[string]*template.Template // template caching map StaticExtensionsToGzip []string
StaticExtensionsToGzip []string // files with should be compressed with gzip (.js,.css,etc)
EnableHttpListen bool
HttpAddr string
HttpPort int
ListenTCP4 bool
EnableHttpTLS bool
HttpsPort int
HttpCertFile string
HttpKeyFile string
RecoverPanic bool // flag of auto recover panic
AutoRender bool // flag of render template automatically
ViewsPath string
AppConfig *beegoAppConfig
RunMode string // run mode, "dev" or "prod"
GlobalSessions *session.Manager // global session mananger
SessionOn bool // flag of starting session auto. default is false.
SessionProvider string // default session provider, memory, mysql , redis ,etc.
SessionName string // the cookie name when saving session id into cookie.
SessionGCMaxLifetime int64 // session gc time for auto cleaning expired session.
SessionSavePath string // if use mysql/redis/file provider, define save path to connection info.
SessionCookieLifeTime int // the life time of session id in cookie.
SessionAutoSetCookie bool // auto setcookie
SessionDomain string // the cookie domain default is empty
UseFcgi bool
UseStdIo bool
MaxMemory int64
EnableGzip bool // flag of enable gzip
DirectoryIndex bool // flag of display directory index. default is false.
HttpServerTimeOut int64
ErrorsShow bool // flag of show errors in page. if true, show error and trace info in page rendered with error template.
XSRFKEY string // xsrf hash salt string.
EnableXSRF bool // flag of enable xsrf.
XSRFExpire int // the expiry of xsrf value.
CopyRequestBody bool // flag of copy raw request body in context.
TemplateLeft string TemplateLeft string
TemplateRight string TemplateRight string
BeegoServerName string // beego server name exported in response header. ViewsPath string
EnableAdmin bool // flag of enable admin module to log every request info. EnableXSRF bool
AdminHttpAddr string // http server configurations for admin module. XSRFKey string
AdminHttpPort int XSRFExpire int
FlashName string // name of the flash variable found in response header and cookie Session SessionConfig
FlashSeperator string // used to seperate flash key:value }
AppConfigProvider string // config provider
EnableDocs bool // enable generate docs & server docs API Swagger type SessionConfig struct {
RouterCaseSensitive bool // router case sensitive default is true SessionOn bool
AccessLogs bool // print access logs, default is false SessionProvider string
Graceful bool // use graceful start the server SessionName string
SessionGCMaxLifetime int64
SessionProviderConfig string
SessionCookieLifeTime int
SessionAutoSetCookie bool
SessionDomain string
}
type LogConfig struct {
AccessLogs bool
FileLineNum bool
Outputs map[string]string // Store Adaptor : config
}
var (
// BConfig is the default config for Application
BConfig *BeegoConfig
// AppConfig is the instance of Config, store the config information from file
AppConfig *beegoAppConfig
// AppConfigPath is the path to the config files
AppConfigPath string
// AppConfigProvider is the provider for the config, default is ini
AppConfigProvider = "ini"
// TemplateCache stores template caching
TemplateCache map[string]*template.Template
// GlobalSessions is the instance for the session manager
GlobalSessions *session.Manager
) )
func init() {
BConfig = &BeegoConfig{
AppName: "beego",
RunMode: DEV,
RouterCaseSensitive: true,
ServerName: "beegoServer:" + VERSION,
RecoverPanic: true,
CopyRequestBody: false,
EnableGzip: false,
MaxMemory: 1 << 26, //64MB
EnableErrorsShow: true,
Listen: Listen{
Graceful: false,
ServerTimeOut: 0,
ListenTCP4: false,
EnableHTTP: true,
HTTPAddr: "",
HTTPPort: 8080,
EnableHTTPS: false,
HTTPSAddr: "",
HTTPSPort: 10443,
HTTPSCertFile: "",
HTTPSKeyFile: "",
EnableAdmin: false,
AdminAddr: "",
AdminPort: 8088,
EnableFcgi: false,
EnableStdIo: false,
},
WebConfig: WebConfig{
AutoRender: true,
EnableDocs: false,
FlashName: "BEEGO_FLASH",
FlashSeparator: "BEEGOFLASH",
DirectoryIndex: false,
StaticDir: map[string]string{"/static": "static"},
StaticExtensionsToGzip: []string{".css", ".js"},
TemplateLeft: "{{",
TemplateRight: "}}",
ViewsPath: "views",
EnableXSRF: false,
XSRFKey: "beegoxsrf",
XSRFExpire: 0,
Session: SessionConfig{
SessionOn: false,
SessionProvider: "memory",
SessionName: "beegosessionID",
SessionGCMaxLifetime: 3600,
SessionProviderConfig: "",
SessionCookieLifeTime: 0, //set cookie default is the brower life
SessionAutoSetCookie: true,
SessionDomain: "",
},
},
Log: LogConfig{
AccessLogs: false,
FileLineNum: true,
Outputs: map[string]string{"console": ""},
},
}
ParseConfig()
}
// ParseConfig parsed default config file.
// now only support ini, next will support json.
func ParseConfig() (err error) {
if AppConfigPath == "" {
if utils.FileExists(filepath.Join("conf", "app.conf")) {
AppConfigPath = filepath.Join("conf", "app.conf")
} else {
AppConfig = &beegoAppConfig{config.NewFakeConfig()}
return
}
}
AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath)
if err != nil {
return err
}
// set the runmode first
if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
BConfig.RunMode = envRunMode
} else if runmode := AppConfig.String("RunMode"); runmode != "" {
BConfig.RunMode = runmode
}
BConfig.AppName = AppConfig.DefaultString("AppName", BConfig.AppName)
BConfig.RecoverPanic = AppConfig.DefaultBool("RecoverPanic", BConfig.RecoverPanic)
BConfig.RouterCaseSensitive = AppConfig.DefaultBool("RouterCaseSensitive", BConfig.RouterCaseSensitive)
BConfig.ServerName = AppConfig.DefaultString("ServerName", BConfig.ServerName)
BConfig.EnableGzip = AppConfig.DefaultBool("EnableGzip", BConfig.EnableGzip)
BConfig.EnableErrorsShow = AppConfig.DefaultBool("EnableErrorsShow", BConfig.EnableErrorsShow)
BConfig.CopyRequestBody = AppConfig.DefaultBool("CopyRequestBody", BConfig.CopyRequestBody)
BConfig.MaxMemory = AppConfig.DefaultInt64("MaxMemory", BConfig.MaxMemory)
BConfig.Listen.Graceful = AppConfig.DefaultBool("Graceful", BConfig.Listen.Graceful)
BConfig.Listen.HTTPAddr = AppConfig.String("HTTPAddr")
BConfig.Listen.HTTPPort = AppConfig.DefaultInt("HTTPPort", BConfig.Listen.HTTPPort)
BConfig.Listen.ListenTCP4 = AppConfig.DefaultBool("ListenTCP4", BConfig.Listen.ListenTCP4)
BConfig.Listen.EnableHTTP = AppConfig.DefaultBool("EnableHTTP", BConfig.Listen.EnableHTTP)
BConfig.Listen.EnableHTTPS = AppConfig.DefaultBool("EnableHTTPS", BConfig.Listen.EnableHTTPS)
BConfig.Listen.HTTPSAddr = AppConfig.DefaultString("HTTPSAddr", BConfig.Listen.HTTPSAddr)
BConfig.Listen.HTTPSPort = AppConfig.DefaultInt("HTTPSPort", BConfig.Listen.HTTPSPort)
BConfig.Listen.HTTPSCertFile = AppConfig.DefaultString("HTTPSCertFile", BConfig.Listen.HTTPSCertFile)
BConfig.Listen.HTTPSKeyFile = AppConfig.DefaultString("HTTPSKeyFile", BConfig.Listen.HTTPSKeyFile)
BConfig.Listen.EnableAdmin = AppConfig.DefaultBool("EnableAdmin", BConfig.Listen.EnableAdmin)
BConfig.Listen.AdminAddr = AppConfig.DefaultString("AdminAddr", BConfig.Listen.AdminAddr)
BConfig.Listen.AdminPort = AppConfig.DefaultInt("AdminPort", BConfig.Listen.AdminPort)
BConfig.Listen.EnableFcgi = AppConfig.DefaultBool("EnableFcgi", BConfig.Listen.EnableFcgi)
BConfig.Listen.EnableStdIo = AppConfig.DefaultBool("EnableStdIo", BConfig.Listen.EnableStdIo)
BConfig.Listen.ServerTimeOut = AppConfig.DefaultInt64("ServerTimeOut", BConfig.Listen.ServerTimeOut)
BConfig.WebConfig.AutoRender = AppConfig.DefaultBool("AutoRender", BConfig.WebConfig.AutoRender)
BConfig.WebConfig.ViewsPath = AppConfig.DefaultString("ViewsPath", BConfig.WebConfig.ViewsPath)
BConfig.WebConfig.DirectoryIndex = AppConfig.DefaultBool("DirectoryIndex", BConfig.WebConfig.DirectoryIndex)
BConfig.WebConfig.FlashName = AppConfig.DefaultString("FlashName", BConfig.WebConfig.FlashName)
BConfig.WebConfig.FlashSeparator = AppConfig.DefaultString("FlashSeparator", BConfig.WebConfig.FlashSeparator)
BConfig.WebConfig.EnableDocs = AppConfig.DefaultBool("EnableDocs", BConfig.WebConfig.EnableDocs)
BConfig.WebConfig.XSRFKey = AppConfig.DefaultString("XSRFKEY", BConfig.WebConfig.XSRFKey)
BConfig.WebConfig.EnableXSRF = AppConfig.DefaultBool("EnableXSRF", BConfig.WebConfig.EnableXSRF)
BConfig.WebConfig.XSRFExpire = AppConfig.DefaultInt("XSRFExpire", BConfig.WebConfig.XSRFExpire)
BConfig.WebConfig.TemplateLeft = AppConfig.DefaultString("TemplateLeft", BConfig.WebConfig.TemplateLeft)
BConfig.WebConfig.TemplateRight = AppConfig.DefaultString("TemplateRight", BConfig.WebConfig.TemplateRight)
BConfig.WebConfig.Session.SessionOn = AppConfig.DefaultBool("SessionOn", BConfig.WebConfig.Session.SessionOn)
BConfig.WebConfig.Session.SessionProvider = AppConfig.DefaultString("SessionProvider", BConfig.WebConfig.Session.SessionProvider)
BConfig.WebConfig.Session.SessionName = AppConfig.DefaultString("SessionName", BConfig.WebConfig.Session.SessionName)
BConfig.WebConfig.Session.SessionProviderConfig = AppConfig.DefaultString("SessionProviderConfig", BConfig.WebConfig.Session.SessionProviderConfig)
BConfig.WebConfig.Session.SessionGCMaxLifetime = AppConfig.DefaultInt64("SessionGCMaxLifetime", BConfig.WebConfig.Session.SessionGCMaxLifetime)
BConfig.WebConfig.Session.SessionCookieLifeTime = AppConfig.DefaultInt("SessionCookieLifeTime", BConfig.WebConfig.Session.SessionCookieLifeTime)
BConfig.WebConfig.Session.SessionAutoSetCookie = AppConfig.DefaultBool("SessionAutoSetCookie", BConfig.WebConfig.Session.SessionAutoSetCookie)
BConfig.WebConfig.Session.SessionDomain = AppConfig.DefaultString("SessionDomain", BConfig.WebConfig.Session.SessionDomain)
if sd := AppConfig.String("StaticDir"); sd != "" {
for k := range BConfig.WebConfig.StaticDir {
delete(BConfig.WebConfig.StaticDir, k)
}
sds := strings.Fields(sd)
for _, v := range sds {
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
} else {
BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
}
}
}
if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" {
extensions := strings.Split(sgz, ",")
fileExts := []string{}
for _, ext := range extensions {
ext = strings.TrimSpace(ext)
if ext == "" {
continue
}
if !strings.HasPrefix(ext, ".") {
ext = "." + ext
}
fileExts = append(fileExts, ext)
}
if len(fileExts) > 0 {
BConfig.WebConfig.StaticExtensionsToGzip = fileExts
}
}
return nil
}
type beegoAppConfig struct { type beegoAppConfig struct {
innerConfig config.ConfigContainer innerConfig config.Configer
} }
func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, error) { func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, error) {
@ -94,109 +280,95 @@ func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, err
if err != nil { if err != nil {
return nil, err return nil, err
} }
rac := &beegoAppConfig{ac} return &beegoAppConfig{ac}, nil
return rac, nil
} }
func (b *beegoAppConfig) Set(key, val string) error { func (b *beegoAppConfig) Set(key, val string) error {
err := b.innerConfig.Set(RunMode+"::"+key, val) if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil {
if err == nil {
return err return err
} }
return b.innerConfig.Set(key, val) return b.innerConfig.Set(key, val)
} }
func (b *beegoAppConfig) String(key string) string { func (b *beegoAppConfig) String(key string) string {
v := b.innerConfig.String(RunMode + "::" + key) if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" {
if v == "" { return v
return b.innerConfig.String(key)
} }
return v return b.innerConfig.String(key)
} }
func (b *beegoAppConfig) Strings(key string) []string { func (b *beegoAppConfig) Strings(key string) []string {
v := b.innerConfig.Strings(RunMode + "::" + key) if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); v[0] != "" {
if v[0] == "" { return v
return b.innerConfig.Strings(key)
} }
return v return b.innerConfig.Strings(key)
} }
func (b *beegoAppConfig) Int(key string) (int, error) { func (b *beegoAppConfig) Int(key string) (int, error) {
v, err := b.innerConfig.Int(RunMode + "::" + key) if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil {
if err != nil { return v, nil
return b.innerConfig.Int(key)
} }
return v, nil return b.innerConfig.Int(key)
} }
func (b *beegoAppConfig) Int64(key string) (int64, error) { func (b *beegoAppConfig) Int64(key string) (int64, error) {
v, err := b.innerConfig.Int64(RunMode + "::" + key) if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil {
if err != nil { return v, nil
return b.innerConfig.Int64(key)
} }
return v, nil return b.innerConfig.Int64(key)
} }
func (b *beegoAppConfig) Bool(key string) (bool, error) { func (b *beegoAppConfig) Bool(key string) (bool, error) {
v, err := b.innerConfig.Bool(RunMode + "::" + key) if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil {
if err != nil { return v, nil
return b.innerConfig.Bool(key)
} }
return v, nil return b.innerConfig.Bool(key)
} }
func (b *beegoAppConfig) Float(key string) (float64, error) { func (b *beegoAppConfig) Float(key string) (float64, error) {
v, err := b.innerConfig.Float(RunMode + "::" + key) if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil {
if err != nil { return v, nil
return b.innerConfig.Float(key)
} }
return v, nil return b.innerConfig.Float(key)
} }
func (b *beegoAppConfig) DefaultString(key string, defaultval string) string { func (b *beegoAppConfig) DefaultString(key string, defaultval string) string {
v := b.String(key) if v := b.String(key); v != "" {
if v != "" {
return v return v
} }
return defaultval return defaultval
} }
func (b *beegoAppConfig) DefaultStrings(key string, defaultval []string) []string { func (b *beegoAppConfig) DefaultStrings(key string, defaultval []string) []string {
v := b.Strings(key) if v := b.Strings(key); len(v) != 0 {
if len(v) != 0 {
return v return v
} }
return defaultval return defaultval
} }
func (b *beegoAppConfig) DefaultInt(key string, defaultval int) int { func (b *beegoAppConfig) DefaultInt(key string, defaultval int) int {
v, err := b.Int(key) if v, err := b.Int(key); err == nil {
if err == nil {
return v return v
} }
return defaultval return defaultval
} }
func (b *beegoAppConfig) DefaultInt64(key string, defaultval int64) int64 { func (b *beegoAppConfig) DefaultInt64(key string, defaultval int64) int64 {
v, err := b.Int64(key) if v, err := b.Int64(key); err == nil {
if err == nil {
return v return v
} }
return defaultval return defaultval
} }
func (b *beegoAppConfig) DefaultBool(key string, defaultval bool) bool { func (b *beegoAppConfig) DefaultBool(key string, defaultval bool) bool {
v, err := b.Bool(key) if v, err := b.Bool(key); err == nil {
if err == nil {
return v return v
} }
return defaultval return defaultval
} }
func (b *beegoAppConfig) DefaultFloat(key string, defaultval float64) float64 { func (b *beegoAppConfig) DefaultFloat(key string, defaultval float64) float64 {
v, err := b.Float(key) if v, err := b.Float(key); err == nil {
if err == nil {
return v return v
} }
return defaultval return defaultval
@ -213,305 +385,3 @@ func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) {
func (b *beegoAppConfig) SaveConfigFile(filename string) error { func (b *beegoAppConfig) SaveConfigFile(filename string) error {
return b.innerConfig.SaveConfigFile(filename) return b.innerConfig.SaveConfigFile(filename)
} }
func init() {
// create beego application
BeeApp = NewApp()
workPath, _ = os.Getwd()
workPath, _ = filepath.Abs(workPath)
// initialize default configurations
AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
if workPath != AppPath {
if utils.FileExists(AppConfigPath) {
os.Chdir(AppPath)
} else {
AppConfigPath = filepath.Join(workPath, "conf", "app.conf")
}
}
AppConfigProvider = "ini"
StaticDir = make(map[string]string)
StaticDir["/static"] = "static"
StaticExtensionsToGzip = []string{".css", ".js"}
TemplateCache = make(map[string]*template.Template)
// set this to 0.0.0.0 to make this app available to externally
EnableHttpListen = true //default enable http Listen
HttpAddr = ""
HttpPort = 8080
HttpsPort = 10443
AppName = "beego"
RunMode = "dev" //default runmod
AutoRender = true
RecoverPanic = true
ViewsPath = "views"
SessionOn = false
SessionProvider = "memory"
SessionName = "beegosessionID"
SessionGCMaxLifetime = 3600
SessionSavePath = ""
SessionCookieLifeTime = 0 //set cookie default is the brower life
SessionAutoSetCookie = true
UseFcgi = false
UseStdIo = false
MaxMemory = 1 << 26 //64MB
EnableGzip = false
HttpServerTimeOut = 0
ErrorsShow = true
XSRFKEY = "beegoxsrf"
XSRFExpire = 0
TemplateLeft = "{{"
TemplateRight = "}}"
BeegoServerName = "beegoServer:" + VERSION
EnableAdmin = false
AdminHttpAddr = "127.0.0.1"
AdminHttpPort = 8088
FlashName = "BEEGO_FLASH"
FlashSeperator = "BEEGOFLASH"
RouterCaseSensitive = true
runtime.GOMAXPROCS(runtime.NumCPU())
// init BeeLogger
BeeLogger = logs.NewLogger(10000)
err := BeeLogger.SetLogger("console", "")
if err != nil {
fmt.Println("init console log error:", err)
}
SetLogFuncCall(true)
err = ParseConfig()
if err != nil && os.IsNotExist(err) {
// for init if doesn't have app.conf will not panic
ac := config.NewFakeConfig()
AppConfig = &beegoAppConfig{ac}
Warning(err)
}
}
// ParseConfig parsed default config file.
// now only support ini, next will support json.
func ParseConfig() (err error) {
AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath)
if err != nil {
return err
}
envRunMode := os.Getenv("BEEGO_RUNMODE")
// set the runmode first
if envRunMode != "" {
RunMode = envRunMode
} else if runmode := AppConfig.String("RunMode"); runmode != "" {
RunMode = runmode
}
HttpAddr = AppConfig.String("HttpAddr")
if v, err := AppConfig.Int("HttpPort"); err == nil {
HttpPort = v
}
if v, err := AppConfig.Bool("ListenTCP4"); err == nil {
ListenTCP4 = v
}
if v, err := AppConfig.Bool("EnableHttpListen"); err == nil {
EnableHttpListen = v
}
if maxmemory, err := AppConfig.Int64("MaxMemory"); err == nil {
MaxMemory = maxmemory
}
if appname := AppConfig.String("AppName"); appname != "" {
AppName = appname
}
if autorender, err := AppConfig.Bool("AutoRender"); err == nil {
AutoRender = autorender
}
if autorecover, err := AppConfig.Bool("RecoverPanic"); err == nil {
RecoverPanic = autorecover
}
if views := AppConfig.String("ViewsPath"); views != "" {
ViewsPath = views
}
if sessionon, err := AppConfig.Bool("SessionOn"); err == nil {
SessionOn = sessionon
}
if sessProvider := AppConfig.String("SessionProvider"); sessProvider != "" {
SessionProvider = sessProvider
}
if sessName := AppConfig.String("SessionName"); sessName != "" {
SessionName = sessName
}
if sesssavepath := AppConfig.String("SessionSavePath"); sesssavepath != "" {
SessionSavePath = sesssavepath
}
if sessMaxLifeTime, err := AppConfig.Int64("SessionGCMaxLifetime"); err == nil && sessMaxLifeTime != 0 {
SessionGCMaxLifetime = sessMaxLifeTime
}
if sesscookielifetime, err := AppConfig.Int("SessionCookieLifeTime"); err == nil && sesscookielifetime != 0 {
SessionCookieLifeTime = sesscookielifetime
}
if usefcgi, err := AppConfig.Bool("UseFcgi"); err == nil {
UseFcgi = usefcgi
}
if enablegzip, err := AppConfig.Bool("EnableGzip"); err == nil {
EnableGzip = enablegzip
}
if directoryindex, err := AppConfig.Bool("DirectoryIndex"); err == nil {
DirectoryIndex = directoryindex
}
if timeout, err := AppConfig.Int64("HttpServerTimeOut"); err == nil {
HttpServerTimeOut = timeout
}
if errorsshow, err := AppConfig.Bool("ErrorsShow"); err == nil {
ErrorsShow = errorsshow
}
if copyrequestbody, err := AppConfig.Bool("CopyRequestBody"); err == nil {
CopyRequestBody = copyrequestbody
}
if xsrfkey := AppConfig.String("XSRFKEY"); xsrfkey != "" {
XSRFKEY = xsrfkey
}
if enablexsrf, err := AppConfig.Bool("EnableXSRF"); err == nil {
EnableXSRF = enablexsrf
}
if expire, err := AppConfig.Int("XSRFExpire"); err == nil {
XSRFExpire = expire
}
if tplleft := AppConfig.String("TemplateLeft"); tplleft != "" {
TemplateLeft = tplleft
}
if tplright := AppConfig.String("TemplateRight"); tplright != "" {
TemplateRight = tplright
}
if httptls, err := AppConfig.Bool("EnableHttpTLS"); err == nil {
EnableHttpTLS = httptls
}
if httpsport, err := AppConfig.Int("HttpsPort"); err == nil {
HttpsPort = httpsport
}
if certfile := AppConfig.String("HttpCertFile"); certfile != "" {
HttpCertFile = certfile
}
if keyfile := AppConfig.String("HttpKeyFile"); keyfile != "" {
HttpKeyFile = keyfile
}
if serverName := AppConfig.String("BeegoServerName"); serverName != "" {
BeegoServerName = serverName
}
if flashname := AppConfig.String("FlashName"); flashname != "" {
FlashName = flashname
}
if flashseperator := AppConfig.String("FlashSeperator"); flashseperator != "" {
FlashSeperator = flashseperator
}
if sd := AppConfig.String("StaticDir"); sd != "" {
for k := range StaticDir {
delete(StaticDir, k)
}
sds := strings.Fields(sd)
for _, v := range sds {
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
} else {
StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
}
}
}
if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" {
extensions := strings.Split(sgz, ",")
if len(extensions) > 0 {
StaticExtensionsToGzip = []string{}
for _, ext := range extensions {
if len(ext) == 0 {
continue
}
extWithDot := ext
if extWithDot[:1] != "." {
extWithDot = "." + extWithDot
}
StaticExtensionsToGzip = append(StaticExtensionsToGzip, extWithDot)
}
}
}
if enableadmin, err := AppConfig.Bool("EnableAdmin"); err == nil {
EnableAdmin = enableadmin
}
if adminhttpaddr := AppConfig.String("AdminHttpAddr"); adminhttpaddr != "" {
AdminHttpAddr = adminhttpaddr
}
if adminhttpport, err := AppConfig.Int("AdminHttpPort"); err == nil {
AdminHttpPort = adminhttpport
}
if enabledocs, err := AppConfig.Bool("EnableDocs"); err == nil {
EnableDocs = enabledocs
}
if casesensitive, err := AppConfig.Bool("RouterCaseSensitive"); err == nil {
RouterCaseSensitive = casesensitive
}
if graceful, err := AppConfig.Bool("Graceful"); err == nil {
Graceful = graceful
}
return nil
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package config is used to parse config
// Usage: // Usage:
// import( // import(
// "github.com/astaxie/beego/config" // "github.com/astaxie/beego/config"
@ -28,12 +29,12 @@
// cnf.Int64(key string) (int64, error) // cnf.Int64(key string) (int64, error)
// cnf.Bool(key string) (bool, error) // cnf.Bool(key string) (bool, error)
// cnf.Float(key string) (float64, error) // cnf.Float(key string) (float64, error)
// cnf.DefaultString(key string, defaultval string) string // cnf.DefaultString(key string, defaultVal string) string
// cnf.DefaultStrings(key string, defaultval []string) []string // cnf.DefaultStrings(key string, defaultVal []string) []string
// cnf.DefaultInt(key string, defaultval int) int // cnf.DefaultInt(key string, defaultVal int) int
// cnf.DefaultInt64(key string, defaultval int64) int64 // cnf.DefaultInt64(key string, defaultVal int64) int64
// cnf.DefaultBool(key string, defaultval bool) bool // cnf.DefaultBool(key string, defaultVal bool) bool
// cnf.DefaultFloat(key string, defaultval float64) float64 // cnf.DefaultFloat(key string, defaultVal float64) float64
// cnf.DIY(key string) (interface{}, error) // cnf.DIY(key string) (interface{}, error)
// cnf.GetSection(section string) (map[string]string, error) // cnf.GetSection(section string) (map[string]string, error)
// cnf.SaveConfigFile(filename string) error // cnf.SaveConfigFile(filename string) error
@ -45,30 +46,30 @@ import (
"fmt" "fmt"
) )
// ConfigContainer defines how to get and set value from configuration raw data. // Configer defines how to get and set value from configuration raw data.
type ConfigContainer interface { type Configer interface {
Set(key, val string) error // support section::key type in given key when using ini type. Set(key, val string) error //support section::key type in given key when using ini type.
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Strings(key string) []string //get string slice Strings(key string) []string //get string slice
Int(key string) (int, error) Int(key string) (int, error)
Int64(key string) (int64, error) Int64(key string) (int64, error)
Bool(key string) (bool, error) Bool(key string) (bool, error)
Float(key string) (float64, error) Float(key string) (float64, error)
DefaultString(key string, defaultval string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
DefaultStrings(key string, defaultval []string) []string //get string slice DefaultStrings(key string, defaultVal []string) []string //get string slice
DefaultInt(key string, defaultval int) int DefaultInt(key string, defaultVal int) int
DefaultInt64(key string, defaultval int64) int64 DefaultInt64(key string, defaultVal int64) int64
DefaultBool(key string, defaultval bool) bool DefaultBool(key string, defaultVal bool) bool
DefaultFloat(key string, defaultval float64) float64 DefaultFloat(key string, defaultVal float64) float64
DIY(key string) (interface{}, error) DIY(key string) (interface{}, error)
GetSection(section string) (map[string]string, error) GetSection(section string) (map[string]string, error)
SaveConfigFile(filename string) error SaveConfigFile(filename string) error
} }
// Config is the adapter interface for parsing config file to get raw data to ConfigContainer. // Config is the adapter interface for parsing config file to get raw data to Configer.
type Config interface { type Config interface {
Parse(key string) (ConfigContainer, error) Parse(key string) (Configer, error)
ParseData(data []byte) (ConfigContainer, error) ParseData(data []byte) (Configer, error)
} }
var adapters = make(map[string]Config) var adapters = make(map[string]Config)
@ -86,19 +87,19 @@ func Register(name string, adapter Config) {
adapters[name] = adapter adapters[name] = adapter
} }
// adapterName is ini/json/xml/yaml. // NewConfig adapterName is ini/json/xml/yaml.
// filename is the config file path. // filename is the config file path.
func NewConfig(adapterName, fileaname string) (ConfigContainer, error) { func NewConfig(adapterName, filename string) (Configer, error) {
adapter, ok := adapters[adapterName] adapter, ok := adapters[adapterName]
if !ok { if !ok {
return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName)
} }
return adapter.Parse(fileaname) return adapter.Parse(filename)
} }
// adapterName is ini/json/xml/yaml. // NewConfigData adapterName is ini/json/xml/yaml.
// data is the config data. // data is the config data.
func NewConfigData(adapterName string, data []byte) (ConfigContainer, error) { func NewConfigData(adapterName string, data []byte) (Configer, error) {
adapter, ok := adapters[adapterName] adapter, ok := adapters[adapterName]
if !ok { if !ok {
return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName) return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName)

View File

@ -38,11 +38,11 @@ func (c *fakeConfigContainer) String(key string) string {
} }
func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string {
if v := c.getData(key); v == "" { v := c.getData(key)
if v == "" {
return defaultval return defaultval
} else {
return v
} }
return v
} }
func (c *fakeConfigContainer) Strings(key string) []string { func (c *fakeConfigContainer) Strings(key string) []string {
@ -50,11 +50,11 @@ func (c *fakeConfigContainer) Strings(key string) []string {
} }
func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) == 0 { v := c.Strings(key)
if len(v) == 0 {
return defaultval return defaultval
} else {
return v
} }
return v
} }
func (c *fakeConfigContainer) Int(key string) (int, error) { func (c *fakeConfigContainer) Int(key string) (int, error) {
@ -62,11 +62,11 @@ func (c *fakeConfigContainer) Int(key string) (int, error) {
} }
func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err != nil { v, err := c.Int(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
func (c *fakeConfigContainer) Int64(key string) (int64, error) { func (c *fakeConfigContainer) Int64(key string) (int64, error) {
@ -74,11 +74,11 @@ func (c *fakeConfigContainer) Int64(key string) (int64, error) {
} }
func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err != nil { v, err := c.Int64(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
func (c *fakeConfigContainer) Bool(key string) (bool, error) { func (c *fakeConfigContainer) Bool(key string) (bool, error) {
@ -86,11 +86,11 @@ func (c *fakeConfigContainer) Bool(key string) (bool, error) {
} }
func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err != nil { v, err := c.Bool(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
func (c *fakeConfigContainer) Float(key string) (float64, error) { func (c *fakeConfigContainer) Float(key string) (float64, error) {
@ -98,11 +98,11 @@ func (c *fakeConfigContainer) Float(key string) (float64, error) {
} }
func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err != nil { v, err := c.Float(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
@ -120,9 +120,10 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error {
return errors.New("not implement in the fakeConfigContainer") return errors.New("not implement in the fakeConfigContainer")
} }
var _ ConfigContainer = new(fakeConfigContainer) var _ Configer = new(fakeConfigContainer)
func NewFakeConfig() ConfigContainer { // NewFakeConfig return a fake Congiger
func NewFakeConfig() Configer {
return &fakeConfigContainer{ return &fakeConfigContainer{
data: make(map[string]string), data: make(map[string]string),
} }

View File

@ -31,23 +31,23 @@ import (
) )
var ( var (
DEFAULT_SECTION = "default" // default section means if some ini items not in a section, make them in default section, defaultSection = "default" // default section means if some ini items not in a section, make them in default section,
bNumComment = []byte{'#'} // number signal bNumComment = []byte{'#'} // number signal
bSemComment = []byte{';'} // semicolon signal bSemComment = []byte{';'} // semicolon signal
bEmpty = []byte{} bEmpty = []byte{}
bEqual = []byte{'='} // equal signal bEqual = []byte{'='} // equal signal
bDQuote = []byte{'"'} // quote signal bDQuote = []byte{'"'} // quote signal
sectionStart = []byte{'['} // section start signal sectionStart = []byte{'['} // section start signal
sectionEnd = []byte{']'} // section end signal sectionEnd = []byte{']'} // section end signal
lineBreak = "\n" lineBreak = "\n"
) )
// IniConfig implements Config to parse ini file. // IniConfig implements Config to parse ini file.
type IniConfig struct { type IniConfig struct {
} }
// ParseFile creates a new Config and parses the file configuration from the named file. // Parse creates a new Config and parses the file configuration from the named file.
func (ini *IniConfig) Parse(name string) (ConfigContainer, error) { func (ini *IniConfig) Parse(name string) (Configer, error) {
return ini.parseFile(name) return ini.parseFile(name)
} }
@ -77,7 +77,7 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
buf.ReadByte() buf.ReadByte()
} }
} }
section := DEFAULT_SECTION section := defaultSection
for { for {
line, _, err := buf.ReadLine() line, _, err := buf.ReadLine()
if err == io.EOF { if err == io.EOF {
@ -171,7 +171,8 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
return cfg, nil return cfg, nil
} }
func (ini *IniConfig) ParseData(data []byte) (ConfigContainer, error) { // ParseData parse ini the data
func (ini *IniConfig) ParseData(data []byte) (Configer, error) {
// Save memory data to temporary file // Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm) os.MkdirAll(path.Dir(tmpName), os.ModePerm)
@ -181,7 +182,7 @@ func (ini *IniConfig) ParseData(data []byte) (ConfigContainer, error) {
return ini.Parse(tmpName) return ini.Parse(tmpName)
} }
// A Config represents the ini configuration. // IniConfigContainer A Config represents the ini configuration.
// When set and get value, support key as section:name type. // When set and get value, support key as section:name type.
type IniConfigContainer struct { type IniConfigContainer struct {
filename string filename string
@ -199,11 +200,11 @@ func (c *IniConfigContainer) Bool(key string) (bool, error) {
// DefaultBool returns the boolean value for a given key. // DefaultBool returns the boolean value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err != nil { v, err := c.Bool(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Int returns the integer value for a given key. // Int returns the integer value for a given key.
@ -214,11 +215,11 @@ func (c *IniConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key. // DefaultInt returns the integer value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err != nil { v, err := c.Int(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Int64 returns the int64 value for a given key. // Int64 returns the int64 value for a given key.
@ -229,11 +230,11 @@ func (c *IniConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key. // DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err != nil { v, err := c.Int64(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Float returns the float value for a given key. // Float returns the float value for a given key.
@ -244,11 +245,11 @@ func (c *IniConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key. // DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err != nil { v, err := c.Float(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// String returns the string value for a given key. // String returns the string value for a given key.
@ -259,11 +260,11 @@ func (c *IniConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key. // DefaultString returns the string value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { func (c *IniConfigContainer) DefaultString(key string, defaultval string) string {
if v := c.String(key); v == "" { v := c.String(key)
if v == "" {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Strings returns the []string value for a given key. // Strings returns the []string value for a given key.
@ -274,20 +275,19 @@ func (c *IniConfigContainer) Strings(key string) []string {
// DefaultStrings returns the []string value for a given key. // DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) == 0 { v := c.Strings(key)
if len(v) == 0 {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// GetSection returns map for the given section // GetSection returns map for the given section
func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) { func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok { if v, ok := c.data[section]; ok {
return v, nil return v, nil
} else {
return nil, errors.New("not exist setction")
} }
return nil, errors.New("not exist setction")
} }
// SaveConfigFile save the config into file // SaveConfigFile save the config into file
@ -301,7 +301,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
// Save default section at first place // Save default section at first place
if dt, ok := c.data[DEFAULT_SECTION]; ok { if dt, ok := c.data[defaultSection]; ok {
for key, val := range dt { for key, val := range dt {
if key != " " { if key != " " {
// Write key comments. // Write key comments.
@ -325,7 +325,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
} }
// Save named sections // Save named sections
for section, dt := range c.data { for section, dt := range c.data {
if section != DEFAULT_SECTION { if section != defaultSection {
// Write section comments. // Write section comments.
if v, ok := c.sectionComment[section]; ok { if v, ok := c.sectionComment[section]; ok {
if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil { if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil {
@ -367,7 +367,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
return nil return nil
} }
// WriteValue writes a new value for key. // Set writes a new value for key.
// if write to one section, the key need be "section::key". // if write to one section, the key need be "section::key".
// if the section is not existed, it panics. // if the section is not existed, it panics.
func (c *IniConfigContainer) Set(key, value string) error { func (c *IniConfigContainer) Set(key, value string) error {
@ -379,14 +379,14 @@ func (c *IniConfigContainer) Set(key, value string) error {
var ( var (
section, k string section, k string
sectionKey []string = strings.Split(key, "::") sectionKey = strings.Split(key, "::")
) )
if len(sectionKey) >= 2 { if len(sectionKey) >= 2 {
section = sectionKey[0] section = sectionKey[0]
k = sectionKey[1] k = sectionKey[1]
} else { } else {
section = DEFAULT_SECTION section = defaultSection
k = sectionKey[0] k = sectionKey[0]
} }
@ -415,13 +415,13 @@ func (c *IniConfigContainer) getdata(key string) string {
var ( var (
section, k string section, k string
sectionKey []string = strings.Split(strings.ToLower(key), "::") sectionKey = strings.Split(strings.ToLower(key), "::")
) )
if len(sectionKey) >= 2 { if len(sectionKey) >= 2 {
section = sectionKey[0] section = sectionKey[0]
k = sectionKey[1] k = sectionKey[1]
} else { } else {
section = DEFAULT_SECTION section = defaultSection
k = sectionKey[0] k = sectionKey[0]
} }
if v, ok := c.data[section]; ok { if v, ok := c.data[section]; ok {

View File

@ -23,12 +23,12 @@ import (
"sync" "sync"
) )
// JsonConfig is a json config parser and implements Config interface. // JSONConfig is a json config parser and implements Config interface.
type JsonConfig struct { type JSONConfig struct {
} }
// Parse returns a ConfigContainer with parsed json config map. // Parse returns a ConfigContainer with parsed json config map.
func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) { func (js *JSONConfig) Parse(filename string) (Configer, error) {
file, err := os.Open(filename) file, err := os.Open(filename)
if err != nil { if err != nil {
return nil, err return nil, err
@ -43,8 +43,8 @@ func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
} }
// ParseData returns a ConfigContainer with json string // ParseData returns a ConfigContainer with json string
func (js *JsonConfig) ParseData(data []byte) (ConfigContainer, error) { func (js *JSONConfig) ParseData(data []byte) (Configer, error) {
x := &JsonConfigContainer{ x := &JSONConfigContainer{
data: make(map[string]interface{}), data: make(map[string]interface{}),
} }
err := json.Unmarshal(data, &x.data) err := json.Unmarshal(data, &x.data)
@ -59,15 +59,15 @@ func (js *JsonConfig) ParseData(data []byte) (ConfigContainer, error) {
return x, nil return x, nil
} }
// A Config represents the json configuration. // JSONConfigContainer A Config represents the json configuration.
// Only when get value, support key as section:name type. // Only when get value, support key as section:name type.
type JsonConfigContainer struct { type JSONConfigContainer struct {
data map[string]interface{} data map[string]interface{}
sync.RWMutex sync.RWMutex
} }
// Bool returns the boolean value for a given key. // Bool returns the boolean value for a given key.
func (c *JsonConfigContainer) Bool(key string) (bool, error) { func (c *JSONConfigContainer) Bool(key string) (bool, error) {
val := c.getData(key) val := c.getData(key)
if val != nil { if val != nil {
if v, ok := val.(bool); ok { if v, ok := val.(bool); ok {
@ -80,7 +80,7 @@ func (c *JsonConfigContainer) Bool(key string) (bool, error) {
// DefaultBool return the bool value if has no error // DefaultBool return the bool value if has no error
// otherwise return the defaultval // otherwise return the defaultval
func (c *JsonConfigContainer) DefaultBool(key string, defaultval bool) bool { func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err == nil { if v, err := c.Bool(key); err == nil {
return v return v
} }
@ -88,7 +88,7 @@ func (c *JsonConfigContainer) DefaultBool(key string, defaultval bool) bool {
} }
// Int returns the integer value for a given key. // Int returns the integer value for a given key.
func (c *JsonConfigContainer) Int(key string) (int, error) { func (c *JSONConfigContainer) Int(key string) (int, error) {
val := c.getData(key) val := c.getData(key)
if val != nil { if val != nil {
if v, ok := val.(float64); ok { if v, ok := val.(float64); ok {
@ -101,7 +101,7 @@ func (c *JsonConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key. // DefaultInt returns the integer value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *JsonConfigContainer) DefaultInt(key string, defaultval int) int { func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err == nil { if v, err := c.Int(key); err == nil {
return v return v
} }
@ -109,7 +109,7 @@ func (c *JsonConfigContainer) DefaultInt(key string, defaultval int) int {
} }
// Int64 returns the int64 value for a given key. // Int64 returns the int64 value for a given key.
func (c *JsonConfigContainer) Int64(key string) (int64, error) { func (c *JSONConfigContainer) Int64(key string) (int64, error) {
val := c.getData(key) val := c.getData(key)
if val != nil { if val != nil {
if v, ok := val.(float64); ok { if v, ok := val.(float64); ok {
@ -122,7 +122,7 @@ func (c *JsonConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key. // DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *JsonConfigContainer) DefaultInt64(key string, defaultval int64) int64 { func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err == nil { if v, err := c.Int64(key); err == nil {
return v return v
} }
@ -130,7 +130,7 @@ func (c *JsonConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
} }
// Float returns the float value for a given key. // Float returns the float value for a given key.
func (c *JsonConfigContainer) Float(key string) (float64, error) { func (c *JSONConfigContainer) Float(key string) (float64, error) {
val := c.getData(key) val := c.getData(key)
if val != nil { if val != nil {
if v, ok := val.(float64); ok { if v, ok := val.(float64); ok {
@ -143,7 +143,7 @@ func (c *JsonConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key. // DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *JsonConfigContainer) DefaultFloat(key string, defaultval float64) float64 { func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err == nil { if v, err := c.Float(key); err == nil {
return v return v
} }
@ -151,7 +151,7 @@ func (c *JsonConfigContainer) DefaultFloat(key string, defaultval float64) float
} }
// String returns the string value for a given key. // String returns the string value for a given key.
func (c *JsonConfigContainer) String(key string) string { func (c *JSONConfigContainer) String(key string) string {
val := c.getData(key) val := c.getData(key)
if val != nil { if val != nil {
if v, ok := val.(string); ok { if v, ok := val.(string); ok {
@ -163,7 +163,7 @@ func (c *JsonConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key. // DefaultString returns the string value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *JsonConfigContainer) DefaultString(key string, defaultval string) string { func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string {
// TODO FIXME should not use "" to replace non existance // TODO FIXME should not use "" to replace non existance
if v := c.String(key); v != "" { if v := c.String(key); v != "" {
return v return v
@ -172,7 +172,7 @@ func (c *JsonConfigContainer) DefaultString(key string, defaultval string) strin
} }
// Strings returns the []string value for a given key. // Strings returns the []string value for a given key.
func (c *JsonConfigContainer) Strings(key string) []string { func (c *JSONConfigContainer) Strings(key string) []string {
stringVal := c.String(key) stringVal := c.String(key)
if stringVal == "" { if stringVal == "" {
return []string{} return []string{}
@ -182,7 +182,7 @@ func (c *JsonConfigContainer) Strings(key string) []string {
// DefaultStrings returns the []string value for a given key. // DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *JsonConfigContainer) DefaultStrings(key string, defaultval []string) []string { func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) > 0 { if v := c.Strings(key); len(v) > 0 {
return v return v
} }
@ -190,7 +190,7 @@ func (c *JsonConfigContainer) DefaultStrings(key string, defaultval []string) []
} }
// GetSection returns map for the given section // GetSection returns map for the given section
func (c *JsonConfigContainer) GetSection(section string) (map[string]string, error) { func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok { if v, ok := c.data[section]; ok {
return v.(map[string]string), nil return v.(map[string]string), nil
} }
@ -198,7 +198,7 @@ func (c *JsonConfigContainer) GetSection(section string) (map[string]string, err
} }
// SaveConfigFile save the config into file // SaveConfigFile save the config into file
func (c *JsonConfigContainer) SaveConfigFile(filename string) (err error) { func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename. // Write configuration file by filename.
f, err := os.Create(filename) f, err := os.Create(filename)
if err != nil { if err != nil {
@ -214,7 +214,7 @@ func (c *JsonConfigContainer) SaveConfigFile(filename string) (err error) {
} }
// Set writes a new value for key. // Set writes a new value for key.
func (c *JsonConfigContainer) Set(key, val string) error { func (c *JSONConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
c.data[key] = val c.data[key] = val
@ -222,7 +222,7 @@ func (c *JsonConfigContainer) Set(key, val string) error {
} }
// DIY returns the raw value by a given key. // DIY returns the raw value by a given key.
func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) { func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) {
val := c.getData(key) val := c.getData(key)
if val != nil { if val != nil {
return val, nil return val, nil
@ -231,7 +231,7 @@ func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
} }
// section.key or key // section.key or key
func (c *JsonConfigContainer) getData(key string) interface{} { func (c *JSONConfigContainer) getData(key string) interface{} {
if len(key) == 0 { if len(key) == 0 {
return nil return nil
} }
@ -261,5 +261,5 @@ func (c *JsonConfigContainer) getData(key string) interface{} {
} }
func init() { func init() {
Register("json", &JsonConfig{}) Register("json", &JSONConfig{})
} }

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// package xml for config provider // Package xml for config provider
// //
// depend on github.com/beego/x2j // depend on github.com/beego/x2j
// //
@ -45,20 +45,20 @@ import (
"github.com/beego/x2j" "github.com/beego/x2j"
) )
// XmlConfig is a xml config parser and implements Config interface. // Config is a xml config parser and implements Config interface.
// xml configurations should be included in <config></config> tag. // xml configurations should be included in <config></config> tag.
// only support key/value pair as <key>value</key> as each item. // only support key/value pair as <key>value</key> as each item.
type XMLConfig struct{} type Config struct{}
// Parse returns a ConfigContainer with parsed xml config map. // Parse returns a ConfigContainer with parsed xml config map.
func (xc *XMLConfig) Parse(filename string) (config.ConfigContainer, error) { func (xc *Config) Parse(filename string) (config.Configer, error) {
file, err := os.Open(filename) file, err := os.Open(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer file.Close() defer file.Close()
x := &XMLConfigContainer{data: make(map[string]interface{})} x := &ConfigContainer{data: make(map[string]interface{})}
content, err := ioutil.ReadAll(file) content, err := ioutil.ReadAll(file)
if err != nil { if err != nil {
return nil, err return nil, err
@ -73,84 +73,86 @@ func (xc *XMLConfig) Parse(filename string) (config.ConfigContainer, error) {
return x, nil return x, nil
} }
func (x *XMLConfig) ParseData(data []byte) (config.ConfigContainer, error) { // ParseData xml data
func (xc *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file // Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm) os.MkdirAll(path.Dir(tmpName), os.ModePerm)
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil { if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
return nil, err return nil, err
} }
return x.Parse(tmpName) return xc.Parse(tmpName)
} }
// A Config represents the xml configuration. // ConfigContainer A Config represents the xml configuration.
type XMLConfigContainer struct { type ConfigContainer struct {
data map[string]interface{} data map[string]interface{}
sync.Mutex sync.Mutex
} }
// Bool returns the boolean value for a given key. // Bool returns the boolean value for a given key.
func (c *XMLConfigContainer) Bool(key string) (bool, error) { func (c *ConfigContainer) Bool(key string) (bool, error) {
return strconv.ParseBool(c.data[key].(string)) return strconv.ParseBool(c.data[key].(string))
} }
// DefaultBool return the bool value if has no error // DefaultBool return the bool value if has no error
// otherwise return the defaultval // otherwise return the defaultval
func (c *XMLConfigContainer) DefaultBool(key string, defaultval bool) bool { func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err != nil { v, err := c.Bool(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Int returns the integer value for a given key. // Int returns the integer value for a given key.
func (c *XMLConfigContainer) Int(key string) (int, error) { func (c *ConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.data[key].(string)) return strconv.Atoi(c.data[key].(string))
} }
// DefaultInt returns the integer value for a given key. // DefaultInt returns the integer value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *XMLConfigContainer) DefaultInt(key string, defaultval int) int { func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err != nil { v, err := c.Int(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Int64 returns the int64 value for a given key. // Int64 returns the int64 value for a given key.
func (c *XMLConfigContainer) Int64(key string) (int64, error) { func (c *ConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.data[key].(string), 10, 64) return strconv.ParseInt(c.data[key].(string), 10, 64)
} }
// DefaultInt64 returns the int64 value for a given key. // DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *XMLConfigContainer) DefaultInt64(key string, defaultval int64) int64 { func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err != nil { v, err := c.Int64(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Float returns the float value for a given key. // Float returns the float value for a given key.
func (c *XMLConfigContainer) Float(key string) (float64, error) { func (c *ConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.data[key].(string), 64) return strconv.ParseFloat(c.data[key].(string), 64)
} }
// DefaultFloat returns the float64 value for a given key. // DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *XMLConfigContainer) DefaultFloat(key string, defaultval float64) float64 { func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err != nil { v, err := c.Float(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// String returns the string value for a given key. // String returns the string value for a given key.
func (c *XMLConfigContainer) String(key string) string { func (c *ConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok { if v, ok := c.data[key].(string); ok {
return v return v
} }
@ -159,40 +161,39 @@ func (c *XMLConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key. // DefaultString returns the string value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *XMLConfigContainer) DefaultString(key string, defaultval string) string { func (c *ConfigContainer) DefaultString(key string, defaultval string) string {
if v := c.String(key); v == "" { v := c.String(key)
if v == "" {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Strings returns the []string value for a given key. // Strings returns the []string value for a given key.
func (c *XMLConfigContainer) Strings(key string) []string { func (c *ConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";") return strings.Split(c.String(key), ";")
} }
// DefaultStrings returns the []string value for a given key. // DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *XMLConfigContainer) DefaultStrings(key string, defaultval []string) []string { func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) == 0 { v := c.Strings(key)
if len(v) == 0 {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// GetSection returns map for the given section // GetSection returns map for the given section
func (c *XMLConfigContainer) GetSection(section string) (map[string]string, error) { func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok { if v, ok := c.data[section]; ok {
return v.(map[string]string), nil return v.(map[string]string), nil
} else {
return nil, errors.New("not exist setction")
} }
return nil, errors.New("not exist setction")
} }
// SaveConfigFile save the config into file // SaveConfigFile save the config into file
func (c *XMLConfigContainer) SaveConfigFile(filename string) (err error) { func (c *ConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename. // Write configuration file by filename.
f, err := os.Create(filename) f, err := os.Create(filename)
if err != nil { if err != nil {
@ -207,8 +208,8 @@ func (c *XMLConfigContainer) SaveConfigFile(filename string) (err error) {
return err return err
} }
// WriteValue writes a new value for key. // Set writes a new value for key.
func (c *XMLConfigContainer) Set(key, val string) error { func (c *ConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
c.data[key] = val c.data[key] = val
@ -216,7 +217,7 @@ func (c *XMLConfigContainer) Set(key, val string) error {
} }
// DIY returns the raw value by a given key. // DIY returns the raw value by a given key.
func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) { func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok { if v, ok := c.data[key]; ok {
return v, nil return v, nil
} }
@ -224,5 +225,5 @@ func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
} }
func init() { func init() {
config.Register("xml", &XMLConfig{}) config.Register("xml", &Config{})
} }

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// package yaml for config provider // Package yaml for config provider
// //
// depend on github.com/beego/goyaml2 // depend on github.com/beego/goyaml2
// //
@ -46,22 +46,23 @@ import (
"github.com/beego/goyaml2" "github.com/beego/goyaml2"
) )
// YAMLConfig is a yaml config parser and implements Config interface. // Config is a yaml config parser and implements Config interface.
type YAMLConfig struct{} type Config struct{}
// Parse returns a ConfigContainer with parsed yaml config map. // Parse returns a ConfigContainer with parsed yaml config map.
func (yaml *YAMLConfig) Parse(filename string) (y config.ConfigContainer, err error) { func (yaml *Config) Parse(filename string) (y config.Configer, err error) {
cnf, err := ReadYmlReader(filename) cnf, err := ReadYmlReader(filename)
if err != nil { if err != nil {
return return
} }
y = &YAMLConfigContainer{ y = &ConfigContainer{
data: cnf, data: cnf,
} }
return return
} }
func (yaml *YAMLConfig) ParseData(data []byte) (config.ConfigContainer, error) { // ParseData parse yaml data
func (yaml *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file // Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm) os.MkdirAll(path.Dir(tmpName), os.ModePerm)
@ -71,7 +72,7 @@ func (yaml *YAMLConfig) ParseData(data []byte) (config.ConfigContainer, error) {
return yaml.Parse(tmpName) return yaml.Parse(tmpName)
} }
// Read yaml file to map. // ReadYmlReader Read yaml file to map.
// if json like, use json package, unless goyaml2 package. // if json like, use json package, unless goyaml2 package.
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
f, err := os.Open(path) f, err := os.Open(path)
@ -112,14 +113,14 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
return return
} }
// A Config represents the yaml configuration. // ConfigContainer A Config represents the yaml configuration.
type YAMLConfigContainer struct { type ConfigContainer struct {
data map[string]interface{} data map[string]interface{}
sync.Mutex sync.Mutex
} }
// Bool returns the boolean value for a given key. // Bool returns the boolean value for a given key.
func (c *YAMLConfigContainer) Bool(key string) (bool, error) { func (c *ConfigContainer) Bool(key string) (bool, error) {
if v, ok := c.data[key].(bool); ok { if v, ok := c.data[key].(bool); ok {
return v, nil return v, nil
} }
@ -128,16 +129,16 @@ func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
// DefaultBool return the bool value if has no error // DefaultBool return the bool value if has no error
// otherwise return the defaultval // otherwise return the defaultval
func (c *YAMLConfigContainer) DefaultBool(key string, defaultval bool) bool { func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err != nil { v, err := c.Bool(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Int returns the integer value for a given key. // Int returns the integer value for a given key.
func (c *YAMLConfigContainer) Int(key string) (int, error) { func (c *ConfigContainer) Int(key string) (int, error) {
if v, ok := c.data[key].(int64); ok { if v, ok := c.data[key].(int64); ok {
return int(v), nil return int(v), nil
} }
@ -146,16 +147,16 @@ func (c *YAMLConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key. // DefaultInt returns the integer value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *YAMLConfigContainer) DefaultInt(key string, defaultval int) int { func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err != nil { v, err := c.Int(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Int64 returns the int64 value for a given key. // Int64 returns the int64 value for a given key.
func (c *YAMLConfigContainer) Int64(key string) (int64, error) { func (c *ConfigContainer) Int64(key string) (int64, error) {
if v, ok := c.data[key].(int64); ok { if v, ok := c.data[key].(int64); ok {
return v, nil return v, nil
} }
@ -164,16 +165,16 @@ func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key. // DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *YAMLConfigContainer) DefaultInt64(key string, defaultval int64) int64 { func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err != nil { v, err := c.Int64(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Float returns the float value for a given key. // Float returns the float value for a given key.
func (c *YAMLConfigContainer) Float(key string) (float64, error) { func (c *ConfigContainer) Float(key string) (float64, error) {
if v, ok := c.data[key].(float64); ok { if v, ok := c.data[key].(float64); ok {
return v, nil return v, nil
} }
@ -182,16 +183,16 @@ func (c *YAMLConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key. // DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *YAMLConfigContainer) DefaultFloat(key string, defaultval float64) float64 { func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err != nil { v, err := c.Float(key)
if err != nil {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// String returns the string value for a given key. // String returns the string value for a given key.
func (c *YAMLConfigContainer) String(key string) string { func (c *ConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok { if v, ok := c.data[key].(string); ok {
return v return v
} }
@ -200,40 +201,40 @@ func (c *YAMLConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key. // DefaultString returns the string value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *YAMLConfigContainer) DefaultString(key string, defaultval string) string { func (c *ConfigContainer) DefaultString(key string, defaultval string) string {
if v := c.String(key); v == "" { v := c.String(key)
if v == "" {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// Strings returns the []string value for a given key. // Strings returns the []string value for a given key.
func (c *YAMLConfigContainer) Strings(key string) []string { func (c *ConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";") return strings.Split(c.String(key), ";")
} }
// DefaultStrings returns the []string value for a given key. // DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval // if err != nil return defaltval
func (c *YAMLConfigContainer) DefaultStrings(key string, defaultval []string) []string { func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) == 0 { v := c.Strings(key)
if len(v) == 0 {
return defaultval return defaultval
} else {
return v
} }
return v
} }
// GetSection returns map for the given section // GetSection returns map for the given section
func (c *YAMLConfigContainer) GetSection(section string) (map[string]string, error) { func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok { v, ok := c.data[section]
if ok {
return v.(map[string]string), nil return v.(map[string]string), nil
} else {
return nil, errors.New("not exist setction")
} }
return nil, errors.New("not exist setction")
} }
// SaveConfigFile save the config into file // SaveConfigFile save the config into file
func (c *YAMLConfigContainer) SaveConfigFile(filename string) (err error) { func (c *ConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename. // Write configuration file by filename.
f, err := os.Create(filename) f, err := os.Create(filename)
if err != nil { if err != nil {
@ -244,8 +245,8 @@ func (c *YAMLConfigContainer) SaveConfigFile(filename string) (err error) {
return err return err
} }
// WriteValue writes a new value for key. // Set writes a new value for key.
func (c *YAMLConfigContainer) Set(key, val string) error { func (c *ConfigContainer) Set(key, val string) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
c.data[key] = val c.data[key] = val
@ -253,7 +254,7 @@ func (c *YAMLConfigContainer) Set(key, val string) error {
} }
// DIY returns the raw value by a given key. // DIY returns the raw value by a given key.
func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) { func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok { if v, ok := c.data[key]; ok {
return v, nil return v, nil
} }
@ -261,5 +262,5 @@ func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
} }
func init() { func init() {
config.Register("yaml", &YAMLConfig{}) config.Register("yaml", &Config{})
} }

View File

@ -19,11 +19,11 @@ import (
) )
func TestDefaults(t *testing.T) { func TestDefaults(t *testing.T) {
if FlashName != "BEEGO_FLASH" { if BConfig.WebConfig.FlashName != "BEEGO_FLASH" {
t.Errorf("FlashName was not set to default.") t.Errorf("FlashName was not set to default.")
} }
if FlashSeperator != "BEEGOFLASH" { if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" {
t.Errorf("FlashName was not set to default.") t.Errorf("FlashName was not set to default.")
} }
} }

198
context/acceptencoder.go Normal file
View File

@ -0,0 +1,198 @@
// Copyright 2015 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"bytes"
"compress/flate"
"compress/gzip"
"compress/zlib"
"io"
"net/http"
"os"
"strconv"
"strings"
"sync"
)
type resetWriter interface {
io.Writer
Reset(w io.Writer)
}
type nopResetWriter struct {
io.Writer
}
func (n nopResetWriter) Reset(w io.Writer) {
//do nothing
}
type acceptEncoder struct {
name string
levelEncode func(int) resetWriter
bestSpeedPool *sync.Pool
bestCompressionPool *sync.Pool
}
func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil {
return nopResetWriter{wr}
}
var rwr resetWriter
switch level {
case flate.BestSpeed:
rwr = ac.bestSpeedPool.Get().(resetWriter)
case flate.BestCompression:
rwr = ac.bestCompressionPool.Get().(resetWriter)
default:
rwr = ac.levelEncode(level)
}
rwr.Reset(wr)
return rwr
}
func (ac acceptEncoder) put(wr resetWriter, level int) {
if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil {
return
}
wr.Reset(nil)
switch level {
case flate.BestSpeed:
ac.bestSpeedPool.Put(wr)
case flate.BestCompression:
ac.bestCompressionPool.Put(wr)
}
}
var (
noneCompressEncoder = acceptEncoder{"", nil, nil, nil}
gzipCompressEncoder = acceptEncoder{"gzip",
func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr },
&sync.Pool{
New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestSpeed); return wr },
},
&sync.Pool{
New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr },
},
}
//according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed
//deflate
//The "zlib" format defined in RFC 1950 [31] in combination with
//the "deflate" compression mechanism described in RFC 1951 [29].
deflateCompressEncoder = acceptEncoder{"deflate",
func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr },
&sync.Pool{
New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestSpeed); return wr },
},
&sync.Pool{
New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr },
},
}
)
var (
encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore
"gzip": gzipCompressEncoder,
"deflate": deflateCompressEncoder,
"*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip
"identity": noneCompressEncoder, // identity means none-compress
}
)
// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate)
func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) {
return writeLevel(encoding, writer, file, flate.BestCompression)
}
// WriteBody reads writes content to writer by the specific encoding(gzip/deflate)
func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) {
return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed)
}
// writeLevel reads from reader,writes to writer by specific encoding and compress level
// the compress level is defined by deflate package
func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) {
var outputWriter resetWriter
var err error
var ce = noneCompressEncoder
if cf, ok := encoderMap[encoding]; ok {
ce = cf
}
encoding = ce.name
outputWriter = ce.encode(writer, level)
defer ce.put(outputWriter, level)
_, err = io.Copy(outputWriter, reader)
if err != nil {
return false, "", err
}
switch outputWriter.(type) {
case io.WriteCloser:
outputWriter.(io.WriteCloser).Close()
}
return encoding != "", encoding, nil
}
// ParseEncoding will extract the right encoding for response
// the Accept-Encoding's sec is here:
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3
func ParseEncoding(r *http.Request) string {
if r == nil {
return ""
}
return parseEncoding(r)
}
type q struct {
name string
value float64
}
func parseEncoding(r *http.Request) string {
acceptEncoding := r.Header.Get("Accept-Encoding")
if acceptEncoding == "" {
return ""
}
var lastQ q
for _, v := range strings.Split(acceptEncoding, ",") {
v = strings.TrimSpace(v)
if v == "" {
continue
}
vs := strings.Split(v, ";")
if len(vs) == 1 {
lastQ = q{vs[0], 1}
break
}
if len(vs) == 2 {
f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64)
if f == 0 {
continue
}
if f > lastQ.value {
lastQ = q{vs[0], f}
}
}
}
if cf, ok := encoderMap[lastQ.name]; ok {
return cf.name
} else {
return ""
}
}

View File

@ -0,0 +1,45 @@
// Copyright 2015 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"net/http"
"testing"
)
func Test_ExtractEncoding(t *testing.T) {
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip,deflate"}}}) != "gzip" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"deflate,gzip"}}}) != "deflate" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=.5,deflate"}}}) != "deflate" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"gzip;q=0,deflate"}}}) != "deflate" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": []string{"*"}}}) != "gzip" {
t.Fail()
}
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package context provide the context utils
// Usage: // Usage:
// //
// import "github.com/astaxie/beego/context" // import "github.com/astaxie/beego/context"
@ -22,10 +23,13 @@
package context package context
import ( import (
"bufio"
"crypto/hmac" "crypto/hmac"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -34,14 +38,30 @@ import (
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
// Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. // NewContext return the Context with Input and Output
func NewContext() *Context {
return &Context{
Input: NewInput(),
Output: NewOutput(),
}
}
// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
// BeegoInput and BeegoOutput provides some api to operate request and response more easily. // BeegoInput and BeegoOutput provides some api to operate request and response more easily.
type Context struct { type Context struct {
Input *BeegoInput Input *BeegoInput
Output *BeegoOutput Output *BeegoOutput
Request *http.Request Request *http.Request
ResponseWriter http.ResponseWriter ResponseWriter *Response
_xsrf_token string _xsrfToken string
}
// Reset init Context, BeegoInput and BeegoOutput
func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
ctx.Request = r
ctx.ResponseWriter = &Response{rw, false, 0}
ctx.Input.Reset(ctx)
ctx.Output.Reset(ctx)
} }
// Redirect does redirection to localurl with http header status code. // Redirect does redirection to localurl with http header status code.
@ -54,29 +74,28 @@ func (ctx *Context) Redirect(status int, localurl string) {
// Abort stops this request. // Abort stops this request.
// if beego.ErrorMaps exists, panic body. // if beego.ErrorMaps exists, panic body.
func (ctx *Context) Abort(status int, body string) { func (ctx *Context) Abort(status int, body string) {
ctx.ResponseWriter.WriteHeader(status)
panic(body) panic(body)
} }
// Write string to response body. // WriteString Write string to response body.
// it sends response body. // it sends response body.
func (ctx *Context) WriteString(content string) { func (ctx *Context) WriteString(content string) {
ctx.ResponseWriter.Write([]byte(content)) ctx.ResponseWriter.Write([]byte(content))
} }
// Get cookie from request by a given key. // GetCookie Get cookie from request by a given key.
// It's alias of BeegoInput.Cookie. // It's alias of BeegoInput.Cookie.
func (ctx *Context) GetCookie(key string) string { func (ctx *Context) GetCookie(key string) string {
return ctx.Input.Cookie(key) return ctx.Input.Cookie(key)
} }
// Set cookie for response. // SetCookie Set cookie for response.
// It's alias of BeegoOutput.Cookie. // It's alias of BeegoOutput.Cookie.
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
ctx.Output.Cookie(name, value, others...) ctx.Output.Cookie(name, value, others...)
} }
// Get secure cookie from request by a given key. // GetSecureCookie Get secure cookie from request by a given key.
func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
val := ctx.Input.Cookie(key) val := ctx.Input.Cookie(key)
if val == "" { if val == "" {
@ -103,7 +122,7 @@ func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
return string(res), true return string(res), true
} }
// Set Secure cookie for response. // SetSecureCookie Set Secure cookie for response.
func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) {
vs := base64.URLEncoding.EncodeToString([]byte(value)) vs := base64.URLEncoding.EncodeToString([]byte(value))
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
@ -114,23 +133,23 @@ func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interf
ctx.Output.Cookie(name, cookie, others...) ctx.Output.Cookie(name, cookie, others...)
} }
// XsrfToken creates a xsrf token string and returns. // XSRFToken creates a xsrf token string and returns.
func (ctx *Context) XsrfToken(key string, expire int64) string { func (ctx *Context) XSRFToken(key string, expire int64) string {
if ctx._xsrf_token == "" { if ctx._xsrfToken == "" {
token, ok := ctx.GetSecureCookie(key, "_xsrf") token, ok := ctx.GetSecureCookie(key, "_xsrf")
if !ok { if !ok {
token = string(utils.RandomCreateBytes(32)) token = string(utils.RandomCreateBytes(32))
ctx.SetSecureCookie(key, "_xsrf", token, expire) ctx.SetSecureCookie(key, "_xsrf", token, expire)
} }
ctx._xsrf_token = token ctx._xsrfToken = token
} }
return ctx._xsrf_token return ctx._xsrfToken
} }
// CheckXsrfCookie checks xsrf token in this request is valid or not. // CheckXSRFCookie checks xsrf token in this request is valid or not.
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" // the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
// or in form field value named as "_xsrf". // or in form field value named as "_xsrf".
func (ctx *Context) CheckXsrfCookie() bool { func (ctx *Context) CheckXSRFCookie() bool {
token := ctx.Input.Query("_xsrf") token := ctx.Input.Query("_xsrf")
if token == "" { if token == "" {
token = ctx.Request.Header.Get("X-Xsrftoken") token = ctx.Request.Header.Get("X-Xsrftoken")
@ -142,9 +161,57 @@ func (ctx *Context) CheckXsrfCookie() bool {
ctx.Abort(403, "'_xsrf' argument missing from POST") ctx.Abort(403, "'_xsrf' argument missing from POST")
return false return false
} }
if ctx._xsrf_token != token { if ctx._xsrfToken != token {
ctx.Abort(403, "XSRF cookie does not match POST argument") ctx.Abort(403, "XSRF cookie does not match POST argument")
return false return false
} }
return true return true
} }
//Response is a wrapper for the http.ResponseWriter
//started set to true if response was written to then don't execute other handler
type Response struct {
http.ResponseWriter
Started bool
Status int
}
// Write writes the data to the connection as part of an HTTP reply,
// and sets `started` to true.
// started means the response has sent out.
func (w *Response) Write(p []byte) (int, error) {
w.Started = true
return w.ResponseWriter.Write(p)
}
// WriteHeader sends an HTTP response header with status code,
// and sets `started` to true.
func (w *Response) WriteHeader(code int) {
w.Status = code
w.Started = true
w.ResponseWriter.WriteHeader(code)
}
// Hijack hijacker for http
func (w *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj, ok := w.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("webserver doesn't support hijacking")
}
return hj.Hijack()
}
// Flush http.Flusher
func (w *Response) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// CloseNotify http.CloseNotifier
func (w *Response) CloseNotify() <-chan bool {
if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
return nil
}

View File

@ -17,8 +17,8 @@ package context
import ( import (
"bytes" "bytes"
"errors" "errors"
"io"
"io/ioutil" "io/ioutil"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"regexp" "regexp"
@ -31,45 +31,55 @@ import (
// Regexes for checking the accept headers // Regexes for checking the accept headers
// TODO make sure these are correct // TODO make sure these are correct
var ( var (
acceptsHtmlRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`)
acceptsXmlRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`)
acceptsJsonRegex = regexp.MustCompile(`(application/json)(?:,|$)`) acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`)
maxParam = 50
) )
// BeegoInput operates the http request header, data, cookie and body. // BeegoInput operates the http request header, data, cookie and body.
// it also contains router params and current session. // it also contains router params and current session.
type BeegoInput struct { type BeegoInput struct {
CruSession session.SessionStore Context *Context
Params map[string]string CruSession session.Store
Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. pnames []string
Request *http.Request pvalues []string
RequestBody []byte data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
RunController reflect.Type RequestBody []byte
RunMethod string
} }
// NewInput return BeegoInput generated by http.Request. // NewInput return BeegoInput generated by Context.
func NewInput(req *http.Request) *BeegoInput { func NewInput() *BeegoInput {
return &BeegoInput{ return &BeegoInput{
Params: make(map[string]string), pnames: make([]string, 0, maxParam),
Data: make(map[interface{}]interface{}), pvalues: make([]string, 0, maxParam),
Request: req, data: make(map[interface{}]interface{}),
} }
} }
// Reset init the BeegoInput
func (input *BeegoInput) Reset(ctx *Context) {
input.Context = ctx
input.CruSession = nil
input.pnames = input.pnames[:0]
input.pvalues = input.pvalues[:0]
input.data = nil
input.RequestBody = []byte{}
}
// Protocol returns request protocol name, such as HTTP/1.1 . // Protocol returns request protocol name, such as HTTP/1.1 .
func (input *BeegoInput) Protocol() string { func (input *BeegoInput) Protocol() string {
return input.Request.Proto return input.Context.Request.Proto
} }
// Uri returns full request url with query string, fragment. // URI returns full request url with query string, fragment.
func (input *BeegoInput) Uri() string { func (input *BeegoInput) URI() string {
return input.Request.RequestURI return input.Context.Request.RequestURI
} }
// Url returns request url path (without query string, fragment). // URL returns request url path (without query string, fragment).
func (input *BeegoInput) Url() string { func (input *BeegoInput) URL() string {
return input.Request.URL.Path return input.Context.Request.URL.Path
} }
// Site returns base site url as scheme://domain type. // Site returns base site url as scheme://domain type.
@ -79,10 +89,10 @@ func (input *BeegoInput) Site() string {
// Scheme returns request scheme as "http" or "https". // Scheme returns request scheme as "http" or "https".
func (input *BeegoInput) Scheme() string { func (input *BeegoInput) Scheme() string {
if input.Request.URL.Scheme != "" { if input.Context.Request.URL.Scheme != "" {
return input.Request.URL.Scheme return input.Context.Request.URL.Scheme
} }
if input.Request.TLS == nil { if input.Context.Request.TLS == nil {
return "http" return "http"
} }
return "https" return "https"
@ -97,19 +107,19 @@ func (input *BeegoInput) Domain() string {
// Host returns host name. // Host returns host name.
// if no host info in request, return localhost. // if no host info in request, return localhost.
func (input *BeegoInput) Host() string { func (input *BeegoInput) Host() string {
if input.Request.Host != "" { if input.Context.Request.Host != "" {
hostParts := strings.Split(input.Request.Host, ":") hostParts := strings.Split(input.Context.Request.Host, ":")
if len(hostParts) > 0 { if len(hostParts) > 0 {
return hostParts[0] return hostParts[0]
} }
return input.Request.Host return input.Context.Request.Host
} }
return "localhost" return "localhost"
} }
// Method returns http request method. // Method returns http request method.
func (input *BeegoInput) Method() string { func (input *BeegoInput) Method() string {
return input.Request.Method return input.Context.Request.Method
} }
// Is returns boolean of this request is on given method, such as Is("POST"). // Is returns boolean of this request is on given method, such as Is("POST").
@ -117,37 +127,37 @@ func (input *BeegoInput) Is(method string) bool {
return input.Method() == method return input.Method() == method
} }
// Is this a GET method request? // IsGet Is this a GET method request?
func (input *BeegoInput) IsGet() bool { func (input *BeegoInput) IsGet() bool {
return input.Is("GET") return input.Is("GET")
} }
// Is this a POST method request? // IsPost Is this a POST method request?
func (input *BeegoInput) IsPost() bool { func (input *BeegoInput) IsPost() bool {
return input.Is("POST") return input.Is("POST")
} }
// Is this a Head method request? // IsHead Is this a Head method request?
func (input *BeegoInput) IsHead() bool { func (input *BeegoInput) IsHead() bool {
return input.Is("HEAD") return input.Is("HEAD")
} }
// Is this a OPTIONS method request? // IsOptions Is this a OPTIONS method request?
func (input *BeegoInput) IsOptions() bool { func (input *BeegoInput) IsOptions() bool {
return input.Is("OPTIONS") return input.Is("OPTIONS")
} }
// Is this a PUT method request? // IsPut Is this a PUT method request?
func (input *BeegoInput) IsPut() bool { func (input *BeegoInput) IsPut() bool {
return input.Is("PUT") return input.Is("PUT")
} }
// Is this a DELETE method request? // IsDelete Is this a DELETE method request?
func (input *BeegoInput) IsDelete() bool { func (input *BeegoInput) IsDelete() bool {
return input.Is("DELETE") return input.Is("DELETE")
} }
// Is this a PATCH method request? // IsPatch Is this a PATCH method request?
func (input *BeegoInput) IsPatch() bool { func (input *BeegoInput) IsPatch() bool {
return input.Is("PATCH") return input.Is("PATCH")
} }
@ -172,19 +182,19 @@ func (input *BeegoInput) IsUpload() bool {
return strings.Contains(input.Header("Content-Type"), "multipart/form-data") return strings.Contains(input.Header("Content-Type"), "multipart/form-data")
} }
// Checks if request accepts html response // AcceptsHTML Checks if request accepts html response
func (input *BeegoInput) AcceptsHtml() bool { func (input *BeegoInput) AcceptsHTML() bool {
return acceptsHtmlRegex.MatchString(input.Header("Accept")) return acceptsHTMLRegex.MatchString(input.Header("Accept"))
} }
// Checks if request accepts xml response // AcceptsXML Checks if request accepts xml response
func (input *BeegoInput) AcceptsXml() bool { func (input *BeegoInput) AcceptsXML() bool {
return acceptsXmlRegex.MatchString(input.Header("Accept")) return acceptsXMLRegex.MatchString(input.Header("Accept"))
} }
// Checks if request accepts json response // AcceptsJSON Checks if request accepts json response
func (input *BeegoInput) AcceptsJson() bool { func (input *BeegoInput) AcceptsJSON() bool {
return acceptsJsonRegex.MatchString(input.Header("Accept")) return acceptsJSONRegex.MatchString(input.Header("Accept"))
} }
// IP returns request client ip. // IP returns request client ip.
@ -196,7 +206,7 @@ func (input *BeegoInput) IP() string {
rip := strings.Split(ips[0], ":") rip := strings.Split(ips[0], ":")
return rip[0] return rip[0]
} }
ip := strings.Split(input.Request.RemoteAddr, ":") ip := strings.Split(input.Context.Request.RemoteAddr, ":")
if len(ip) > 0 { if len(ip) > 0 {
if ip[0] != "[" { if ip[0] != "[" {
return ip[0] return ip[0]
@ -236,7 +246,7 @@ func (input *BeegoInput) SubDomains() string {
// Port returns request client port. // Port returns request client port.
// when error or empty, return 80. // when error or empty, return 80.
func (input *BeegoInput) Port() int { func (input *BeegoInput) Port() int {
parts := strings.Split(input.Request.Host, ":") parts := strings.Split(input.Context.Request.Host, ":")
if len(parts) == 2 { if len(parts) == 2 {
port, _ := strconv.Atoi(parts[1]) port, _ := strconv.Atoi(parts[1])
return port return port
@ -249,35 +259,59 @@ func (input *BeegoInput) UserAgent() string {
return input.Header("User-Agent") return input.Header("User-Agent")
} }
// ParamsLen return the length of the params
func (input *BeegoInput) ParamsLen() int {
return len(input.pnames)
}
// Param returns router param by a given key. // Param returns router param by a given key.
func (input *BeegoInput) Param(key string) string { func (input *BeegoInput) Param(key string) string {
if v, ok := input.Params[key]; ok { for i, v := range input.pnames {
return v if v == key && i <= len(input.pvalues) {
return input.pvalues[i]
}
} }
return "" return ""
} }
// Params returns the map[key]value.
func (input *BeegoInput) Params() map[string]string {
m := make(map[string]string)
for i, v := range input.pnames {
if i <= len(input.pvalues) {
m[v] = input.pvalues[i]
}
}
return m
}
// SetParam will set the param with key and value
func (input *BeegoInput) SetParam(key, val string) {
input.pvalues = append(input.pvalues, val)
input.pnames = append(input.pnames, key)
}
// Query returns input data item string by a given string. // Query returns input data item string by a given string.
func (input *BeegoInput) Query(key string) string { func (input *BeegoInput) Query(key string) string {
if val := input.Param(key); val != "" { if val := input.Param(key); val != "" {
return val return val
} }
if input.Request.Form == nil { if input.Context.Request.Form == nil {
input.Request.ParseForm() input.Context.Request.ParseForm()
} }
return input.Request.Form.Get(key) return input.Context.Request.Form.Get(key)
} }
// Header returns request header item string by a given string. // Header returns request header item string by a given string.
// if non-existed, return empty string. // if non-existed, return empty string.
func (input *BeegoInput) Header(key string) string { func (input *BeegoInput) Header(key string) string {
return input.Request.Header.Get(key) return input.Context.Request.Header.Get(key)
} }
// Cookie returns request cookie item string by a given key. // Cookie returns request cookie item string by a given key.
// if non-existed, return empty string. // if non-existed, return empty string.
func (input *BeegoInput) Cookie(key string) string { func (input *BeegoInput) Cookie(key string) string {
ck, err := input.Request.Cookie(key) ck, err := input.Context.Request.Cookie(key)
if err != nil { if err != nil {
return "" return ""
} }
@ -291,18 +325,27 @@ func (input *BeegoInput) Session(key interface{}) interface{} {
} }
// CopyBody returns the raw request body data as bytes. // CopyBody returns the raw request body data as bytes.
func (input *BeegoInput) CopyBody() []byte { func (input *BeegoInput) CopyBody(MaxMemory int64) []byte {
requestbody, _ := ioutil.ReadAll(input.Request.Body) safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory}
input.Request.Body.Close() requestbody, _ := ioutil.ReadAll(safe)
input.Context.Request.Body.Close()
bf := bytes.NewBuffer(requestbody) bf := bytes.NewBuffer(requestbody)
input.Request.Body = ioutil.NopCloser(bf) input.Context.Request.Body = ioutil.NopCloser(bf)
input.RequestBody = requestbody input.RequestBody = requestbody
return requestbody return requestbody
} }
// Data return the implicit data in the input
func (input *BeegoInput) Data() map[interface{}]interface{} {
if input.data == nil {
input.data = make(map[interface{}]interface{})
}
return input.data
}
// GetData returns the stored data in this context. // GetData returns the stored data in this context.
func (input *BeegoInput) GetData(key interface{}) interface{} { func (input *BeegoInput) GetData(key interface{}) interface{} {
if v, ok := input.Data[key]; ok { if v, ok := input.data[key]; ok {
return v return v
} }
return nil return nil
@ -311,17 +354,20 @@ func (input *BeegoInput) GetData(key interface{}) interface{} {
// SetData stores data with given key in this context. // SetData stores data with given key in this context.
// This data are only available in this context. // This data are only available in this context.
func (input *BeegoInput) SetData(key, val interface{}) { func (input *BeegoInput) SetData(key, val interface{}) {
input.Data[key] = val if input.data == nil {
input.data = make(map[interface{}]interface{})
}
input.data[key] = val
} }
// parseForm or parseMultiForm based on Content-type // ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type
func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error {
// Parse the body depending on the content type. // Parse the body depending on the content type.
if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { if strings.Contains(input.Header("Content-Type"), "multipart/form-data") {
if err := input.Request.ParseMultipartForm(maxMemory); err != nil { if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil {
return errors.New("Error parsing request body:" + err.Error()) return errors.New("Error parsing request body:" + err.Error())
} }
} else if err := input.Request.ParseForm(); err != nil { } else if err := input.Context.Request.ParseForm(); err != nil {
return errors.New("Error parsing request body:" + err.Error()) return errors.New("Error parsing request body:" + err.Error())
} }
return nil return nil
@ -353,7 +399,7 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error {
} }
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
rv := reflect.Zero(reflect.TypeOf(0)) rv := reflect.Zero(typ)
switch typ.Kind() { switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
val := input.Query(key) val := input.Query(key)
@ -386,19 +432,19 @@ func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
} }
rv = input.bindBool(val, typ) rv = input.bindBool(val, typ)
case reflect.Slice: case reflect.Slice:
rv = input.bindSlice(&input.Request.Form, key, typ) rv = input.bindSlice(&input.Context.Request.Form, key, typ)
case reflect.Struct: case reflect.Struct:
rv = input.bindStruct(&input.Request.Form, key, typ) rv = input.bindStruct(&input.Context.Request.Form, key, typ)
case reflect.Ptr: case reflect.Ptr:
rv = input.bindPoint(key, typ) rv = input.bindPoint(key, typ)
case reflect.Map: case reflect.Map:
rv = input.bindMap(&input.Request.Form, key, typ) rv = input.bindMap(&input.Context.Request.Form, key, typ)
} }
return rv return rv
} }
func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value { func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value {
rv := reflect.Zero(reflect.TypeOf(0)) rv := reflect.Zero(typ)
switch typ.Kind() { switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
rv = input.bindInt(val, typ) rv = input.bindInt(val, typ)

View File

@ -17,12 +17,15 @@ package context
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
) )
func TestParse(t *testing.T) { func TestParse(t *testing.T) {
r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
beegoInput := NewInput(r) beegoInput := NewInput()
beegoInput.Context = NewContext()
beegoInput.Context.Reset(httptest.NewRecorder(), r)
beegoInput.ParseFormOrMulitForm(1 << 20) beegoInput.ParseFormOrMulitForm(1 << 20)
var id int var id int
@ -73,7 +76,9 @@ func TestParse(t *testing.T) {
func TestSubDomain(t *testing.T) { func TestSubDomain(t *testing.T) {
r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
beegoInput := NewInput(r) beegoInput := NewInput()
beegoInput.Context = NewContext()
beegoInput.Context.Reset(httptest.NewRecorder(), r)
subdomain := beegoInput.SubDomains() subdomain := beegoInput.SubDomains()
if subdomain != "www" { if subdomain != "www" {
@ -81,13 +86,13 @@ func TestSubDomain(t *testing.T) {
} }
r, _ = http.NewRequest("GET", "http://localhost/", nil) r, _ = http.NewRequest("GET", "http://localhost/", nil)
beegoInput.Request = r beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" { if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains()) t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains())
} }
r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil) r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil)
beegoInput.Request = r beegoInput.Context.Request = r
if beegoInput.SubDomains() != "aa.bb" { if beegoInput.SubDomains() != "aa.bb" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
} }
@ -101,13 +106,13 @@ func TestSubDomain(t *testing.T) {
*/ */
r, _ = http.NewRequest("GET", "http://example.com/", nil) r, _ = http.NewRequest("GET", "http://example.com/", nil)
beegoInput.Request = r beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" { if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
} }
r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil) r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil)
beegoInput.Request = r beegoInput.Context.Request = r
if beegoInput.SubDomains() != "aa.bb.cc.dd" { if beegoInput.SubDomains() != "aa.bb.cc.dd" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains()) t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
} }

View File

@ -16,8 +16,6 @@ package context
import ( import (
"bytes" "bytes"
"compress/flate"
"compress/gzip"
"encoding/json" "encoding/json"
"encoding/xml" "encoding/xml"
"errors" "errors"
@ -45,6 +43,12 @@ func NewOutput() *BeegoOutput {
return &BeegoOutput{} return &BeegoOutput{}
} }
// Reset init BeegoOutput
func (output *BeegoOutput) Reset(ctx *Context) {
output.Context = ctx
output.Status = 0
}
// Header sets response header item string via given key. // Header sets response header item string via given key.
func (output *BeegoOutput) Header(key, val string) { func (output *BeegoOutput) Header(key, val string) {
output.Context.ResponseWriter.Header().Set(key, val) output.Context.ResponseWriter.Header().Set(key, val)
@ -54,30 +58,16 @@ func (output *BeegoOutput) Header(key, val string) {
// if EnableGzip, compress content string. // if EnableGzip, compress content string.
// it sends out response body directly. // it sends out response body directly.
func (output *BeegoOutput) Body(content []byte) { func (output *BeegoOutput) Body(content []byte) {
output_writer := output.Context.ResponseWriter.(io.Writer) var encoding string
if output.EnableGzip == true && output.Context.Input.Header("Accept-Encoding") != "" { var buf = &bytes.Buffer{}
splitted := strings.SplitN(output.Context.Input.Header("Accept-Encoding"), ",", -1) if output.EnableGzip {
encodings := make([]string, len(splitted)) encoding = ParseEncoding(output.Context.Request)
}
for i, val := range splitted { if b, n, _ := WriteBody(encoding, buf, content); b {
encodings[i] = strings.TrimSpace(val) output.Header("Content-Encoding", n)
}
for _, val := range encodings {
if val == "gzip" {
output.Header("Content-Encoding", "gzip")
output_writer, _ = gzip.NewWriterLevel(output.Context.ResponseWriter, gzip.BestSpeed)
break
} else if val == "deflate" {
output.Header("Content-Encoding", "deflate")
output_writer, _ = flate.NewWriter(output.Context.ResponseWriter, flate.BestSpeed)
break
}
}
} else { } else {
output.Header("Content-Length", strconv.Itoa(len(content))) output.Header("Content-Length", strconv.Itoa(len(content)))
} }
// Write status code if it has been set manually // Write status code if it has been set manually
// Set it to 0 afterwards to prevent "multiple response.WriteHeader calls" // Set it to 0 afterwards to prevent "multiple response.WriteHeader calls"
if output.Status != 0 { if output.Status != 0 {
@ -85,13 +75,7 @@ func (output *BeegoOutput) Body(content []byte) {
output.Status = 0 output.Status = 0
} }
output_writer.Write(content) io.Copy(output.Context.ResponseWriter, buf)
switch output_writer.(type) {
case *gzip.Writer:
output_writer.(*gzip.Writer).Close()
case *flate.Writer:
output_writer.(*flate.Writer).Close()
}
} }
// Cookie sets cookie value via given key. // Cookie sets cookie value via given key.
@ -100,29 +84,25 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface
var b bytes.Buffer var b bytes.Buffer
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
//fix cookie not work in IE //fix cookie not work in IE
if len(others) > 0 { if len(others) > 0 {
switch v := others[0].(type) { var maxAge int64
case int:
if v > 0 { switch v := others[0].(type) {
fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v) * time.Second).UTC().Format(time.RFC1123), v) case int:
} else if v < 0 { maxAge = int64(v)
fmt.Fprintf(&b, "; Max-Age=0") case int32:
} maxAge = int64(v)
case int64: case int64:
if v > 0 { maxAge = v
fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v) * time.Second).UTC().Format(time.RFC1123), v) }
} else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0") if maxAge > 0 {
} fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge)
case int32: } else {
if v > 0 { fmt.Fprintf(&b, "; Max-Age=0")
fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(v) * time.Second).UTC().Format(time.RFC1123), v) }
} else if v < 0 { }
fmt.Fprintf(&b, "; Max-Age=0")
}
}
}
// the settings below // the settings below
// Path, Domain, Secure, HttpOnly // Path, Domain, Secure, HttpOnly
@ -188,9 +168,9 @@ func sanitizeValue(v string) string {
return cookieValueSanitizer.Replace(v) return cookieValueSanitizer.Replace(v)
} }
// Json writes json to response body. // JSON writes json to response body.
// if coding is true, it converts utf-8 to \u0000 type. // if coding is true, it converts utf-8 to \u0000 type.
func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) error { func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error {
output.Header("Content-Type", "application/json; charset=utf-8") output.Header("Content-Type", "application/json; charset=utf-8")
var content []byte var content []byte
var err error var err error
@ -204,14 +184,14 @@ func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) e
return err return err
} }
if coding { if coding {
content = []byte(stringsToJson(string(content))) content = []byte(stringsToJSON(string(content)))
} }
output.Body(content) output.Body(content)
return nil return nil
} }
// Jsonp writes jsonp to response body. // JSONP writes jsonp to response body.
func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error { func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/javascript; charset=utf-8") output.Header("Content-Type", "application/javascript; charset=utf-8")
var content []byte var content []byte
var err error var err error
@ -228,16 +208,16 @@ func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
if callback == "" { if callback == "" {
return errors.New(`"callback" parameter required`) return errors.New(`"callback" parameter required`)
} }
callback_content := bytes.NewBufferString(" " + template.JSEscapeString(callback)) callbackContent := bytes.NewBufferString(" " + template.JSEscapeString(callback))
callback_content.WriteString("(") callbackContent.WriteString("(")
callback_content.Write(content) callbackContent.Write(content)
callback_content.WriteString(");\r\n") callbackContent.WriteString(");\r\n")
output.Body(callback_content.Bytes()) output.Body(callbackContent.Bytes())
return nil return nil
} }
// Xml writes xml string to response body. // XML writes xml string to response body.
func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error { func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/xml; charset=utf-8") output.Header("Content-Type", "application/xml; charset=utf-8")
var content []byte var content []byte
var err error var err error
@ -331,7 +311,7 @@ func (output *BeegoOutput) IsNotFound(status int) bool {
return output.Status == 404 return output.Status == 404
} }
// IsClient returns boolean of this request client sends error data. // IsClientError returns boolean of this request client sends error data.
// HTTP 4xx means forbidden. // HTTP 4xx means forbidden.
func (output *BeegoOutput) IsClientError(status int) bool { func (output *BeegoOutput) IsClientError(status int) bool {
return output.Status >= 400 && output.Status < 500 return output.Status >= 400 && output.Status < 500
@ -343,7 +323,7 @@ func (output *BeegoOutput) IsServerError(status int) bool {
return output.Status >= 500 && output.Status < 600 return output.Status >= 500 && output.Status < 600
} }
func stringsToJson(str string) string { func stringsToJSON(str string) string {
rs := []rune(str) rs := []rune(str)
jsons := "" jsons := ""
for _, r := range rs { for _, r := range rs {
@ -357,7 +337,7 @@ func stringsToJson(str string) string {
return jsons return jsons
} }
// Sessions sets session item value with given key. // Session sets session item value with given key.
func (output *BeegoOutput) Session(name interface{}, value interface{}) { func (output *BeegoOutput) Session(name interface{}, value interface{}) {
output.Context.Input.CruSession.Set(name, value) output.Context.Input.CruSession.Set(name, value)
} }

View File

@ -19,7 +19,6 @@ import (
"errors" "errors"
"html/template" "html/template"
"io" "io"
"io/ioutil"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/url" "net/url"
@ -34,18 +33,19 @@ import (
//commonly used mime-types //commonly used mime-types
const ( const (
applicationJson = "application/json" applicationJSON = "application/json"
applicationXml = "application/xml" applicationXML = "application/xml"
textXml = "text/xml" textXML = "text/xml"
) )
var ( var (
// custom error when user stop request handler manually. // ErrAbort custom error when user stop request handler manually.
USERSTOPRUN = errors.New("User stop run") ErrAbort = errors.New("User stop run")
GlobalControllerRouter map[string][]ControllerComments = make(map[string][]ControllerComments) //pkgpath+controller:comments // GlobalControllerRouter store comments with controller. pkgpath+controller:comments
GlobalControllerRouter = make(map[string][]ControllerComments)
) )
// store the comment for the controller method // ControllerComments store the comment for the controller method
type ControllerComments struct { type ControllerComments struct {
Method string Method string
Router string Router string
@ -56,22 +56,31 @@ type ControllerComments struct {
// Controller defines some basic http request handler operations, such as // Controller defines some basic http request handler operations, such as
// http context, template and view, session and xsrf. // http context, template and view, session and xsrf.
type Controller struct { type Controller struct {
Ctx *context.Context // context data
Data map[interface{}]interface{} Ctx *context.Context
Data map[interface{}]interface{}
// route controller info
controllerName string controllerName string
actionName string actionName string
TplNames string methodMapping map[string]func() //method:routertree
gotofunc string
AppController interface{}
// template data
TplName string
Layout string Layout string
LayoutSections map[string]string // the key is the section name and the value is the template name LayoutSections map[string]string // the key is the section name and the value is the template name
TplExt string TplExt string
_xsrf_token string
gotofunc string
CruSession session.SessionStore
XSRFExpire int
AppController interface{}
EnableRender bool EnableRender bool
EnableXSRF bool
methodMapping map[string]func() //method:routertree // xsrf data
_xsrfToken string
XSRFExpire int
EnableXSRF bool
// session
CruSession session.Store
} }
// ControllerInterface is an interface to uniform all controller handler. // ControllerInterface is an interface to uniform all controller handler.
@ -87,8 +96,8 @@ type ControllerInterface interface {
Options() Options()
Finish() Finish()
Render() error Render() error
XsrfToken() string XSRFToken() string
CheckXsrfCookie() bool CheckXSRFCookie() bool
HandlerFunc(fn string) bool HandlerFunc(fn string) bool
URLMapping() URLMapping()
} }
@ -96,7 +105,7 @@ type ControllerInterface interface {
// Init generates default values of controller operations. // Init generates default values of controller operations.
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) { func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
c.Layout = "" c.Layout = ""
c.TplNames = "" c.TplName = ""
c.controllerName = controllerName c.controllerName = controllerName
c.actionName = actionName c.actionName = actionName
c.Ctx = ctx c.Ctx = ctx
@ -104,19 +113,15 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin
c.AppController = app c.AppController = app
c.EnableRender = true c.EnableRender = true
c.EnableXSRF = true c.EnableXSRF = true
c.Data = ctx.Input.Data c.Data = ctx.Input.Data()
c.methodMapping = make(map[string]func()) c.methodMapping = make(map[string]func())
} }
// Prepare runs after Init before request function execution. // Prepare runs after Init before request function execution.
func (c *Controller) Prepare() { func (c *Controller) Prepare() {}
}
// Finish runs after request function execution. // Finish runs after request function execution.
func (c *Controller) Finish() { func (c *Controller) Finish() {}
}
// Get adds a request function to handle GET request. // Get adds a request function to handle GET request.
func (c *Controller) Get() { func (c *Controller) Get() {
@ -153,20 +158,19 @@ func (c *Controller) Options() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
} }
// call function fn // HandlerFunc call function with the name
func (c *Controller) HandlerFunc(fnname string) bool { func (c *Controller) HandlerFunc(fnname string) bool {
if v, ok := c.methodMapping[fnname]; ok { if v, ok := c.methodMapping[fnname]; ok {
v() v()
return true return true
} else {
return false
} }
return false
} }
// URLMapping register the internal Controller router. // URLMapping register the internal Controller router.
func (c *Controller) URLMapping() { func (c *Controller) URLMapping() {}
}
// Mapping the method to function
func (c *Controller) Mapping(method string, fn func()) { func (c *Controller) Mapping(method string, fn func()) {
c.methodMapping[method] = fn c.methodMapping[method] = fn
} }
@ -177,13 +181,11 @@ func (c *Controller) Render() error {
return nil return nil
} }
rb, err := c.RenderBytes() rb, err := c.RenderBytes()
if err != nil { if err != nil {
return err return err
} else {
c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
c.Ctx.Output.Body(rb)
} }
c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
c.Ctx.Output.Body(rb)
return nil return nil
} }
@ -196,24 +198,33 @@ func (c *Controller) RenderString() (string, error) {
// RenderBytes returns the bytes of rendered template string. Do not send out response. // RenderBytes returns the bytes of rendered template string. Do not send out response.
func (c *Controller) RenderBytes() ([]byte, error) { func (c *Controller) RenderBytes() ([]byte, error) {
//if the controller has set layout, then first get the tplname's content set the content to the layout //if the controller has set layout, then first get the tplname's content set the content to the layout
var buf bytes.Buffer
if c.Layout != "" { if c.Layout != "" {
if c.TplNames == "" { if c.TplName == "" {
c.TplNames = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
} }
if RunMode == "dev" {
BuildTemplate(ViewsPath) if BConfig.RunMode == DEV {
buildFiles := []string{c.TplName}
if c.LayoutSections != nil {
for _, sectionTpl := range c.LayoutSections {
if sectionTpl == "" {
continue
}
buildFiles = append(buildFiles, sectionTpl)
}
}
BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...)
} }
newbytes := bytes.NewBufferString("") if _, ok := BeeTemplates[c.TplName]; !ok {
if _, ok := BeeTemplates[c.TplNames]; !ok { panic("can't find templatefile in the path:" + c.TplName)
panic("can't find templatefile in the path:" + c.TplNames)
} }
err := BeeTemplates[c.TplNames].ExecuteTemplate(newbytes, c.TplNames, c.Data) err := BeeTemplates[c.TplName].ExecuteTemplate(&buf, c.TplName, c.Data)
if err != nil { if err != nil {
Trace("template Execute err:", err) Trace("template Execute err:", err)
return nil, err return nil, err
} }
tplcontent, _ := ioutil.ReadAll(newbytes) c.Data["LayoutContent"] = template.HTML(buf.String())
c.Data["LayoutContent"] = template.HTML(string(tplcontent))
if c.LayoutSections != nil { if c.LayoutSections != nil {
for sectionName, sectionTpl := range c.LayoutSections { for sectionName, sectionTpl := range c.LayoutSections {
@ -222,44 +233,41 @@ func (c *Controller) RenderBytes() ([]byte, error) {
continue continue
} }
sectionBytes := bytes.NewBufferString("") buf.Reset()
err = BeeTemplates[sectionTpl].ExecuteTemplate(sectionBytes, sectionTpl, c.Data) err = BeeTemplates[sectionTpl].ExecuteTemplate(&buf, sectionTpl, c.Data)
if err != nil { if err != nil {
Trace("template Execute err:", err) Trace("template Execute err:", err)
return nil, err return nil, err
} }
sectionContent, _ := ioutil.ReadAll(sectionBytes) c.Data[sectionName] = template.HTML(buf.String())
c.Data[sectionName] = template.HTML(string(sectionContent))
} }
} }
ibytes := bytes.NewBufferString("") buf.Reset()
err = BeeTemplates[c.Layout].ExecuteTemplate(ibytes, c.Layout, c.Data) err = BeeTemplates[c.Layout].ExecuteTemplate(&buf, c.Layout, c.Data)
if err != nil { if err != nil {
Trace("template Execute err:", err) Trace("template Execute err:", err)
return nil, err return nil, err
} }
icontent, _ := ioutil.ReadAll(ibytes) return buf.Bytes(), nil
return icontent, nil
} else {
if c.TplNames == "" {
c.TplNames = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
}
if RunMode == "dev" {
BuildTemplate(ViewsPath)
}
ibytes := bytes.NewBufferString("")
if _, ok := BeeTemplates[c.TplNames]; !ok {
panic("can't find templatefile in the path:" + c.TplNames)
}
err := BeeTemplates[c.TplNames].ExecuteTemplate(ibytes, c.TplNames, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
icontent, _ := ioutil.ReadAll(ibytes)
return icontent, nil
} }
if c.TplName == "" {
c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
}
if BConfig.RunMode == DEV {
BuildTemplate(BConfig.WebConfig.ViewsPath, c.TplName)
}
if _, ok := BeeTemplates[c.TplName]; !ok {
panic("can't find templatefile in the path:" + c.TplName)
}
buf.Reset()
err := BeeTemplates[c.TplName].ExecuteTemplate(&buf, c.TplName, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
return buf.Bytes(), nil
} }
// Redirect sends the redirection response to url with status code. // Redirect sends the redirection response to url with status code.
@ -267,7 +275,7 @@ func (c *Controller) Redirect(url string, code int) {
c.Ctx.Redirect(code, url) c.Ctx.Redirect(code, url)
} }
// Aborts stops controller handler and show the error data if code is defined in ErrorMap or code string. // Abort stops controller handler and show the error data if code is defined in ErrorMap or code string.
func (c *Controller) Abort(code string) { func (c *Controller) Abort(code string) {
status, err := strconv.Atoi(code) status, err := strconv.Atoi(code)
if err != nil { if err != nil {
@ -285,74 +293,69 @@ func (c *Controller) CustomAbort(status int, body string) {
} }
// last panic user string // last panic user string
c.Ctx.ResponseWriter.Write([]byte(body)) c.Ctx.ResponseWriter.Write([]byte(body))
panic(USERSTOPRUN) panic(ErrAbort)
} }
// StopRun makes panic of USERSTOPRUN error and go to recover function if defined. // StopRun makes panic of USERSTOPRUN error and go to recover function if defined.
func (c *Controller) StopRun() { func (c *Controller) StopRun() {
panic(USERSTOPRUN) panic(ErrAbort)
} }
// UrlFor does another controller handler in this request function. // URLFor does another controller handler in this request function.
// it goes to this controller method if endpoint is not clear. // it goes to this controller method if endpoint is not clear.
func (c *Controller) UrlFor(endpoint string, values ...interface{}) string { func (c *Controller) URLFor(endpoint string, values ...interface{}) string {
if len(endpoint) <= 0 { if len(endpoint) == 0 {
return "" return ""
} }
if endpoint[0] == '.' { if endpoint[0] == '.' {
return UrlFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...) return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...)
} else {
return UrlFor(endpoint, values...)
} }
return URLFor(endpoint, values...)
} }
// ServeJson sends a json response with encoding charset. // ServeJSON sends a json response with encoding charset.
func (c *Controller) ServeJson(encoding ...bool) { func (c *Controller) ServeJSON(encoding ...bool) {
var hasIndent bool var (
var hasencoding bool hasIndent = true
if RunMode == "prod" { hasEncoding = false
)
if BConfig.RunMode == PROD {
hasIndent = false hasIndent = false
} else {
hasIndent = true
} }
if len(encoding) > 0 && encoding[0] == true { if len(encoding) > 0 && encoding[0] == true {
hasencoding = true hasEncoding = true
} }
c.Ctx.Output.Json(c.Data["json"], hasIndent, hasencoding) c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding)
} }
// ServeJsonp sends a jsonp response. // ServeJSONP sends a jsonp response.
func (c *Controller) ServeJsonp() { func (c *Controller) ServeJSONP() {
var hasIndent bool hasIndent := true
if RunMode == "prod" { if BConfig.RunMode == PROD {
hasIndent = false hasIndent = false
} else {
hasIndent = true
} }
c.Ctx.Output.Jsonp(c.Data["jsonp"], hasIndent) c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent)
} }
// ServeXml sends xml response. // ServeXML sends xml response.
func (c *Controller) ServeXml() { func (c *Controller) ServeXML() {
var hasIndent bool hasIndent := true
if RunMode == "prod" { if BConfig.RunMode == PROD {
hasIndent = false hasIndent = false
} else {
hasIndent = true
} }
c.Ctx.Output.Xml(c.Data["xml"], hasIndent) c.Ctx.Output.XML(c.Data["xml"], hasIndent)
} }
// ServeFormatted serve Xml OR Json, depending on the value of the Accept header // ServeFormatted serve Xml OR Json, depending on the value of the Accept header
func (c *Controller) ServeFormatted() { func (c *Controller) ServeFormatted() {
accept := c.Ctx.Input.Header("Accept") accept := c.Ctx.Input.Header("Accept")
switch accept { switch accept {
case applicationJson: case applicationJSON:
c.ServeJson() c.ServeJSON()
case applicationXml, textXml: case applicationXML, textXML:
c.ServeXml() c.ServeXML()
default: default:
c.ServeJson() c.ServeJSON()
} }
} }
@ -371,16 +374,13 @@ func (c *Controller) ParseForm(obj interface{}) error {
// GetString returns the input value by key string or the default value while it's present and input is blank // GetString returns the input value by key string or the default value while it's present and input is blank
func (c *Controller) GetString(key string, def ...string) string { func (c *Controller) GetString(key string, def ...string) string {
var defv string
if len(def) > 0 {
defv = def[0]
}
if v := c.Ctx.Input.Query(key); v != "" { if v := c.Ctx.Input.Query(key); v != "" {
return v return v
} else {
return defv
} }
if len(def) > 0 {
return def[0]
}
return ""
} }
// GetStrings returns the input string slice by key string or the default value while it's present and input is blank // GetStrings returns the input string slice by key string or the default value while it's present and input is blank
@ -391,106 +391,81 @@ func (c *Controller) GetStrings(key string, def ...[]string) []string {
defv = def[0] defv = def[0]
} }
f := c.Input() if f := c.Input(); f == nil {
if f == nil {
return defv return defv
} else {
if vs := f[key]; len(vs) > 0 {
return vs
}
} }
vs := f[key] return defv
if len(vs) > 0 {
return vs
} else {
return defv
}
} }
// GetInt returns input as an int or the default value while it's present and input is blank // GetInt returns input as an int or the default value while it's present and input is blank
func (c *Controller) GetInt(key string, def ...int) (int, error) { func (c *Controller) GetInt(key string, def ...int) (int, error) {
if strv := c.Ctx.Input.Query(key); strv != "" { strv := c.Ctx.Input.Query(key)
return strconv.Atoi(strv) if len(strv) == 0 && len(def) > 0 {
} else if len(def) > 0 {
return def[0], nil return def[0], nil
} else {
return strconv.Atoi(strv)
} }
return strconv.Atoi(strv)
} }
// GetInt8 return input as an int8 or the default value while it's present and input is blank // GetInt8 return input as an int8 or the default value while it's present and input is blank
func (c *Controller) GetInt8(key string, def ...int8) (int8, error) { func (c *Controller) GetInt8(key string, def ...int8) (int8, error) {
if strv := c.Ctx.Input.Query(key); strv != "" { strv := c.Ctx.Input.Query(key)
i64, err := strconv.ParseInt(strv, 10, 8) if len(strv) == 0 && len(def) > 0 {
i8 := int8(i64)
return i8, err
} else if len(def) > 0 {
return def[0], nil return def[0], nil
} else {
i64, err := strconv.ParseInt(strv, 10, 8)
i8 := int8(i64)
return i8, err
} }
i64, err := strconv.ParseInt(strv, 10, 8)
return int8(i64), err
} }
// GetInt16 returns input as an int16 or the default value while it's present and input is blank // GetInt16 returns input as an int16 or the default value while it's present and input is blank
func (c *Controller) GetInt16(key string, def ...int16) (int16, error) { func (c *Controller) GetInt16(key string, def ...int16) (int16, error) {
if strv := c.Ctx.Input.Query(key); strv != "" { strv := c.Ctx.Input.Query(key)
i64, err := strconv.ParseInt(strv, 10, 16) if len(strv) == 0 && len(def) > 0 {
i16 := int16(i64)
return i16, err
} else if len(def) > 0 {
return def[0], nil return def[0], nil
} else {
i64, err := strconv.ParseInt(strv, 10, 16)
i16 := int16(i64)
return i16, err
} }
i64, err := strconv.ParseInt(strv, 10, 16)
return int16(i64), err
} }
// GetInt32 returns input as an int32 or the default value while it's present and input is blank // GetInt32 returns input as an int32 or the default value while it's present and input is blank
func (c *Controller) GetInt32(key string, def ...int32) (int32, error) { func (c *Controller) GetInt32(key string, def ...int32) (int32, error) {
if strv := c.Ctx.Input.Query(key); strv != "" { strv := c.Ctx.Input.Query(key)
i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 32) if len(strv) == 0 && len(def) > 0 {
i32 := int32(i64)
return i32, err
} else if len(def) > 0 {
return def[0], nil return def[0], nil
} else {
i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 32)
i32 := int32(i64)
return i32, err
} }
i64, err := strconv.ParseInt(strv, 10, 32)
return int32(i64), err
} }
// GetInt64 returns input value as int64 or the default value while it's present and input is blank. // GetInt64 returns input value as int64 or the default value while it's present and input is blank.
func (c *Controller) GetInt64(key string, def ...int64) (int64, error) { func (c *Controller) GetInt64(key string, def ...int64) (int64, error) {
if strv := c.Ctx.Input.Query(key); strv != "" { strv := c.Ctx.Input.Query(key)
return strconv.ParseInt(strv, 10, 64) if len(strv) == 0 && len(def) > 0 {
} else if len(def) > 0 {
return def[0], nil return def[0], nil
} else {
return strconv.ParseInt(strv, 10, 64)
} }
return strconv.ParseInt(strv, 10, 64)
} }
// GetBool returns input value as bool or the default value while it's present and input is blank. // GetBool returns input value as bool or the default value while it's present and input is blank.
func (c *Controller) GetBool(key string, def ...bool) (bool, error) { func (c *Controller) GetBool(key string, def ...bool) (bool, error) {
if strv := c.Ctx.Input.Query(key); strv != "" { strv := c.Ctx.Input.Query(key)
return strconv.ParseBool(strv) if len(strv) == 0 && len(def) > 0 {
} else if len(def) > 0 {
return def[0], nil return def[0], nil
} else {
return strconv.ParseBool(strv)
} }
return strconv.ParseBool(strv)
} }
// GetFloat returns input value as float64 or the default value while it's present and input is blank. // GetFloat returns input value as float64 or the default value while it's present and input is blank.
func (c *Controller) GetFloat(key string, def ...float64) (float64, error) { func (c *Controller) GetFloat(key string, def ...float64) (float64, error) {
if strv := c.Ctx.Input.Query(key); strv != "" { strv := c.Ctx.Input.Query(key)
return strconv.ParseFloat(strv, 64) if len(strv) == 0 && len(def) > 0 {
} else if len(def) > 0 {
return def[0], nil return def[0], nil
} else {
return strconv.ParseFloat(strv, 64)
} }
return strconv.ParseFloat(strv, 64)
} }
// GetFile returns the file data in file upload field named as key. // GetFile returns the file data in file upload field named as key.
@ -527,8 +502,7 @@ func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader,
// } // }
// } // }
func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) {
files, ok := c.Ctx.Request.MultipartForm.File[key] if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok {
if ok {
return files, nil return files, nil
} }
return nil, http.ErrMissingFile return nil, http.ErrMissingFile
@ -552,7 +526,7 @@ func (c *Controller) SaveToFile(fromfile, tofile string) error {
} }
// StartSession starts session and load old session data info this controller. // StartSession starts session and load old session data info this controller.
func (c *Controller) StartSession() session.SessionStore { func (c *Controller) StartSession() session.Store {
if c.CruSession == nil { if c.CruSession == nil {
c.CruSession = c.Ctx.Input.CruSession c.CruSession = c.Ctx.Input.CruSession
} }
@ -575,7 +549,7 @@ func (c *Controller) GetSession(name interface{}) interface{} {
return c.CruSession.Get(name) return c.CruSession.Get(name)
} }
// SetSession removes value from session. // DelSession removes value from session.
func (c *Controller) DelSession(name interface{}) { func (c *Controller) DelSession(name interface{}) {
if c.CruSession == nil { if c.CruSession == nil {
c.StartSession() c.StartSession()
@ -589,7 +563,7 @@ func (c *Controller) SessionRegenerateID() {
if c.CruSession != nil { if c.CruSession != nil {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter) c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
} }
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request) c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession c.Ctx.Input.CruSession = c.CruSession
} }
@ -614,37 +588,35 @@ func (c *Controller) SetSecureCookie(Secret, name, value string, others ...inter
c.Ctx.SetSecureCookie(Secret, name, value, others...) c.Ctx.SetSecureCookie(Secret, name, value, others...)
} }
// XsrfToken creates a xsrf token string and returns. // XSRFToken creates a CSRF token string and returns.
func (c *Controller) XsrfToken() string { func (c *Controller) XSRFToken() string {
if c._xsrf_token == "" { if c._xsrfToken == "" {
var expire int64 expire := int64(BConfig.WebConfig.XSRFExpire)
if c.XSRFExpire > 0 { if c.XSRFExpire > 0 {
expire = int64(c.XSRFExpire) expire = int64(c.XSRFExpire)
} else {
expire = int64(XSRFExpire)
} }
c._xsrf_token = c.Ctx.XsrfToken(XSRFKEY, expire) c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire)
} }
return c._xsrf_token return c._xsrfToken
} }
// CheckXsrfCookie checks xsrf token in this request is valid or not. // CheckXSRFCookie checks xsrf token in this request is valid or not.
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" // the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
// or in form field value named as "_xsrf". // or in form field value named as "_xsrf".
func (c *Controller) CheckXsrfCookie() bool { func (c *Controller) CheckXSRFCookie() bool {
if !c.EnableXSRF { if !c.EnableXSRF {
return true return true
} }
return c.Ctx.CheckXsrfCookie() return c.Ctx.CheckXSRFCookie()
} }
// XsrfFormHtml writes an input field contains xsrf token value. // XSRFFormHTML writes an input field contains xsrf token value.
func (c *Controller) XsrfFormHtml() string { func (c *Controller) XSRFFormHTML() string {
return "<input type=\"hidden\" name=\"_xsrf\" value=\"" + return `<input type="hidden" name="_xsrf" value="` +
c._xsrf_token + "\"/>" c.XSRFToken() + `" />`
} }
// GetControllerAndAction gets the executing controller name and action name. // GetControllerAndAction gets the executing controller name and action name.
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) { func (c *Controller) GetControllerAndAction() (string, string) {
return c.controllerName, c.actionName return c.controllerName, c.actionName
} }

View File

@ -15,61 +15,63 @@
package beego package beego
import ( import (
"fmt" "testing"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
) )
func ExampleGetInt() { func TestGetInt(t *testing.T) {
i := context.NewInput()
i := &context.BeegoInput{Params: map[string]string{"age": "40"}} i.SetParam("age", "40")
ctx := &context.Context{Input: i} ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx} ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt("age") val, _ := ctrlr.GetInt("age")
fmt.Printf("%T", val) if val != 40 {
//Output: int t.Errorf("TestGetInt expect 40,get %T,%v", val, val)
}
} }
func ExampleGetInt8() { func TestGetInt8(t *testing.T) {
i := context.NewInput()
i := &context.BeegoInput{Params: map[string]string{"age": "40"}} i.SetParam("age", "40")
ctx := &context.Context{Input: i} ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx} ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt8("age") val, _ := ctrlr.GetInt8("age")
fmt.Printf("%T", val) if val != 40 {
t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val)
}
//Output: int8 //Output: int8
} }
func ExampleGetInt16() { func TestGetInt16(t *testing.T) {
i := context.NewInput()
i := &context.BeegoInput{Params: map[string]string{"age": "40"}} i.SetParam("age", "40")
ctx := &context.Context{Input: i} ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx} ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt16("age") val, _ := ctrlr.GetInt16("age")
fmt.Printf("%T", val) if val != 40 {
//Output: int16 t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val)
}
} }
func ExampleGetInt32() { func TestGetInt32(t *testing.T) {
i := context.NewInput()
i := &context.BeegoInput{Params: map[string]string{"age": "40"}} i.SetParam("age", "40")
ctx := &context.Context{Input: i} ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx} ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt32("age") val, _ := ctrlr.GetInt32("age")
fmt.Printf("%T", val) if val != 40 {
//Output: int32 t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val)
}
} }
func ExampleGetInt64() { func TestGetInt64(t *testing.T) {
i := context.NewInput()
i := &context.BeegoInput{Params: map[string]string{"age": "40"}} i.SetParam("age", "40")
ctx := &context.Context{Input: i} ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx} ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt64("age") val, _ := ctrlr.GetInt64("age")
fmt.Printf("%T", val) if val != 40 {
//Output: int64 t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val)
}
} }

17
doc.go Normal file
View File

@ -0,0 +1,17 @@
/*
Package beego provide a MVC framework
beego: an open-source, high-performance, modular, full-stack web framework
It is used for rapid development of RESTful APIs, web apps and backend services in Go.
beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding.
package main
import "github.com/astaxie/beego"
func main() {
beego.Run()
}
more infomation: http://beego.me
*/
package beego

23
docs.go
View File

@ -15,37 +15,24 @@
package beego package beego
import ( import (
"encoding/json"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
) )
var GlobalDocApi map[string]interface{} // GlobalDocAPI store the swagger api documents
var GlobalDocAPI = make(map[string]interface{})
func init() {
if EnableDocs {
GlobalDocApi = make(map[string]interface{})
}
}
func serverDocs(ctx *context.Context) { func serverDocs(ctx *context.Context) {
var obj interface{} var obj interface{}
if splat := ctx.Input.Param(":splat"); splat == "" { if splat := ctx.Input.Param(":splat"); splat == "" {
obj = GlobalDocApi["Root"] obj = GlobalDocAPI["Root"]
} else { } else {
if v, ok := GlobalDocApi[splat]; ok { if v, ok := GlobalDocAPI[splat]; ok {
obj = v obj = v
} }
} }
if obj != nil { if obj != nil {
bt, err := json.Marshal(obj)
if err != nil {
ctx.Output.SetStatus(504)
return
}
ctx.Output.Header("Content-Type", "application/json;charset=UTF-8")
ctx.Output.Header("Access-Control-Allow-Origin", "*") ctx.Output.Header("Access-Control-Allow-Origin", "*")
ctx.Output.Body(bt) ctx.Output.JSON(obj, false, false)
return return
} }
ctx.Output.SetStatus(404) ctx.Output.SetStatus(404)

203
error.go
View File

@ -82,16 +82,17 @@ var tpl = `
` `
// render default application error page with error and stack string. // render default application error page with error and stack string.
func showErr(err interface{}, ctx *context.Context, Stack string) { func showErr(err interface{}, ctx *context.Context, stack string) {
t, _ := template.New("beegoerrortemp").Parse(tpl) t, _ := template.New("beegoerrortemp").Parse(tpl)
data := make(map[string]string) data := map[string]string{
data["AppError"] = AppName + ":" + fmt.Sprint(err) "AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err),
data["RequestMethod"] = ctx.Input.Method() "RequestMethod": ctx.Input.Method(),
data["RequestURL"] = ctx.Input.Uri() "RequestURL": ctx.Input.URI(),
data["RemoteAddr"] = ctx.Input.IP() "RemoteAddr": ctx.Input.IP(),
data["Stack"] = Stack "Stack": stack,
data["BeegoVersion"] = VERSION "BeegoVersion": VERSION,
data["GoVersion"] = runtime.Version() "GoVersion": runtime.Version(),
}
ctx.ResponseWriter.WriteHeader(500) ctx.ResponseWriter.WriteHeader(500)
t.Execute(ctx.ResponseWriter, data) t.Execute(ctx.ResponseWriter, data)
} }
@ -204,47 +205,48 @@ type errorInfo struct {
} }
// map of http handlers for each error string. // map of http handlers for each error string.
var ErrorMaps map[string]*errorInfo // there is 10 kinds default error(40x and 50x)
var ErrorMaps = make(map[string]*errorInfo, 10)
func init() {
ErrorMaps = make(map[string]*errorInfo)
}
// show 401 unauthorized error. // show 401 unauthorized error.
func unauthorized(rw http.ResponseWriter, r *http.Request) { func unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := map[string]interface{}{
data["Title"] = "Unauthorized" "Title": http.StatusText(401),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested can't be authorized." + data["Content"] = template.HTML("<br>The page you have requested can't be authorized." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>The credentials you supplied are incorrect" + "<br>The credentials you supplied are incorrect" +
"<br>There are errors in the website address" + "<br>There are errors in the website address" +
"</ul>") "</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 402 Payment Required // show 402 Payment Required
func paymentRequired(rw http.ResponseWriter, r *http.Request) { func paymentRequired(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := map[string]interface{}{
data["Title"] = "Payment Required" "Title": http.StatusText(402),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested Payment Required." + data["Content"] = template.HTML("<br>The page you have requested Payment Required." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>The credentials you supplied are incorrect" + "<br>The credentials you supplied are incorrect" +
"<br>There are errors in the website address" + "<br>There are errors in the website address" +
"</ul>") "</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 403 forbidden error. // show 403 forbidden error.
func forbidden(rw http.ResponseWriter, r *http.Request) { func forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := map[string]interface{}{
data["Title"] = "Forbidden" "Title": http.StatusText(403),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is forbidden." + data["Content"] = template.HTML("<br>The page you have requested is forbidden." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
@ -252,15 +254,16 @@ func forbidden(rw http.ResponseWriter, r *http.Request) {
"<br>The site may be disabled" + "<br>The site may be disabled" +
"<br>You need to log in" + "<br>You need to log in" +
"</ul>") "</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 404 notfound error. // show 404 notfound error.
func notFound(rw http.ResponseWriter, r *http.Request) { func notFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := map[string]interface{}{
data["Title"] = "Page Not Found" "Title": http.StatusText(404),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested has flown the coop." + data["Content"] = template.HTML("<br>The page you have requested has flown the coop." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
@ -269,191 +272,158 @@ func notFound(rw http.ResponseWriter, r *http.Request) {
"<br>You were looking for your puppy and got lost" + "<br>You were looking for your puppy and got lost" +
"<br>You like 404 pages" + "<br>You like 404 pages" +
"</ul>") "</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 405 Method Not Allowed // show 405 Method Not Allowed
func methodNotAllowed(rw http.ResponseWriter, r *http.Request) { func methodNotAllowed(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := map[string]interface{}{
data["Title"] = "Method Not Allowed" "Title": http.StatusText(405),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The method you have requested Not Allowed." + data["Content"] = template.HTML("<br>The method you have requested Not Allowed." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br>The method specified in the Request-Line is not allowed for the resource identified by the Request-URI" + "<br>The method specified in the Request-Line is not allowed for the resource identified by the Request-URI" +
"<br>The response MUST include an Allow header containing a list of valid methods for the requested resource." + "<br>The response MUST include an Allow header containing a list of valid methods for the requested resource." +
"</ul>") "</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 500 internal server error. // show 500 internal server error.
func internalServerError(rw http.ResponseWriter, r *http.Request) { func internalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := map[string]interface{}{
data["Title"] = "Internal Server Error" "Title": http.StatusText(500),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is down right now." + data["Content"] = template.HTML("<br>The page you have requested is down right now." +
"<br><br><ul>" + "<br><br><ul>" +
"<br>Please try again later and report the error to the website administrator" + "<br>Please try again later and report the error to the website administrator" +
"<br></ul>") "<br></ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 501 Not Implemented. // show 501 Not Implemented.
func notImplemented(rw http.ResponseWriter, r *http.Request) { func notImplemented(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := map[string]interface{}{
data["Title"] = "Not Implemented" "Title": http.StatusText(504),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is Not Implemented." + data["Content"] = template.HTML("<br>The page you have requested is Not Implemented." +
"<br><br><ul>" + "<br><br><ul>" +
"<br>Please try again later and report the error to the website administrator" + "<br>Please try again later and report the error to the website administrator" +
"<br></ul>") "<br></ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 502 Bad Gateway. // show 502 Bad Gateway.
func badGateway(rw http.ResponseWriter, r *http.Request) { func badGateway(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := map[string]interface{}{
data["Title"] = "Bad Gateway" "Title": http.StatusText(502),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is down right now." + data["Content"] = template.HTML("<br>The page you have requested is down right now." +
"<br><br><ul>" + "<br><br><ul>" +
"<br>The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request." + "<br>The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request." +
"<br>Please try again later and report the error to the website administrator" + "<br>Please try again later and report the error to the website administrator" +
"<br></ul>") "<br></ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 503 service unavailable error. // show 503 service unavailable error.
func serviceUnavailable(rw http.ResponseWriter, r *http.Request) { func serviceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := map[string]interface{}{
data["Title"] = "Service Unavailable" "Title": http.StatusText(503),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is unavailable." + data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br><br>The page is overloaded" + "<br><br>The page is overloaded" +
"<br>Please try again later." + "<br>Please try again later." +
"</ul>") "</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data) t.Execute(rw, data)
} }
// show 504 Gateway Timeout. // show 504 Gateway Timeout.
func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{}) data := map[string]interface{}{
data["Title"] = "Gateway Timeout" "Title": http.StatusText(504),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is unavailable." + data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" + "<br>Perhaps you are here because:" +
"<br><br><ul>" + "<br><br><ul>" +
"<br><br>The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI." + "<br><br>The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI." +
"<br>Please try again later." + "<br>Please try again later." +
"</ul>") "</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data) t.Execute(rw, data)
} }
// register default error http handlers, 404,401,403,500 and 503.
func registerDefaultErrorHandler() {
if _, ok := ErrorMaps["401"]; !ok {
Errorhandler("401", unauthorized)
}
if _, ok := ErrorMaps["402"]; !ok {
Errorhandler("402", paymentRequired)
}
if _, ok := ErrorMaps["403"]; !ok {
Errorhandler("403", forbidden)
}
if _, ok := ErrorMaps["404"]; !ok {
Errorhandler("404", notFound)
}
if _, ok := ErrorMaps["405"]; !ok {
Errorhandler("405", methodNotAllowed)
}
if _, ok := ErrorMaps["500"]; !ok {
Errorhandler("500", internalServerError)
}
if _, ok := ErrorMaps["501"]; !ok {
Errorhandler("501", notImplemented)
}
if _, ok := ErrorMaps["502"]; !ok {
Errorhandler("502", badGateway)
}
if _, ok := ErrorMaps["503"]; !ok {
Errorhandler("503", serviceUnavailable)
}
if _, ok := ErrorMaps["504"]; !ok {
Errorhandler("504", gatewayTimeout)
}
}
// ErrorHandler registers http.HandlerFunc to each http err code string. // ErrorHandler registers http.HandlerFunc to each http err code string.
// usage: // usage:
// beego.ErrorHandler("404",NotFound) // beego.ErrorHandler("404",NotFound)
// beego.ErrorHandler("500",InternalServerError) // beego.ErrorHandler("500",InternalServerError)
func Errorhandler(code string, h http.HandlerFunc) *App { func ErrorHandler(code string, h http.HandlerFunc) *App {
errinfo := &errorInfo{} ErrorMaps[code] = &errorInfo{
errinfo.errorType = errorTypeHandler errorType: errorTypeHandler,
errinfo.handler = h handler: h,
errinfo.method = code method: code,
ErrorMaps[code] = errinfo }
return BeeApp return BeeApp
} }
// ErrorController registers ControllerInterface to each http err code string. // ErrorController registers ControllerInterface to each http err code string.
// usage: // usage:
// beego.ErrorHandler(&controllers.ErrorController{}) // beego.ErrorController(&controllers.ErrorController{})
func ErrorController(c ControllerInterface) *App { func ErrorController(c ControllerInterface) *App {
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type() rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type() ct := reflect.Indirect(reflectVal).Type()
for i := 0; i < rt.NumMethod(); i++ { for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) && strings.HasPrefix(rt.Method(i).Name, "Error") { methodName := rt.Method(i).Name
errinfo := &errorInfo{} if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") {
errinfo.errorType = errorTypeController errName := strings.TrimPrefix(methodName, "Error")
errinfo.controllerType = ct ErrorMaps[errName] = &errorInfo{
errinfo.method = rt.Method(i).Name errorType: errorTypeController,
errname := strings.TrimPrefix(rt.Method(i).Name, "Error") controllerType: ct,
ErrorMaps[errname] = errinfo method: methodName,
}
} }
} }
return BeeApp return BeeApp
} }
// show error string as simple text message. // show error string as simple text message.
// if error string is empty, show 500 error as default. // if error string is empty, show 503 or 500 error as default.
func exception(errcode string, ctx *context.Context) { func exception(errCode string, ctx *context.Context) {
code, err := strconv.Atoi(errcode) atoi := func(code string) int {
if err != nil { v, err := strconv.Atoi(code)
code = 503 if err == nil {
return v
}
return 503
} }
if h, ok := ErrorMaps[errcode]; ok {
executeError(h, ctx, code) for _, ec := range []string{errCode, "503", "500"} {
return if h, ok := ErrorMaps[ec]; ok {
} else if h, ok := ErrorMaps["503"]; ok { executeError(h, ctx, atoi(ec))
executeError(h, ctx, code) return
return }
} else {
ctx.ResponseWriter.WriteHeader(code)
ctx.WriteString(errcode)
} }
//if 50x error has been removed from errorMap
ctx.ResponseWriter.WriteHeader(atoi(errCode))
ctx.WriteString(errCode)
} }
func executeError(err *errorInfo, ctx *context.Context, code int) { func executeError(err *errorInfo, ctx *context.Context, code int) {
if err.errorType == errorTypeHandler { if err.errorType == errorTypeHandler {
ctx.ResponseWriter.WriteHeader(code)
err.handler(ctx.ResponseWriter, ctx.Request) err.handler(ctx.ResponseWriter, ctx.Request)
return return
} }
@ -473,12 +443,11 @@ func executeError(err *errorInfo, ctx *context.Context, code int) {
execController.URLMapping() execController.URLMapping()
in := make([]reflect.Value, 0)
method := vc.MethodByName(err.method) method := vc.MethodByName(err.method)
method.Call(in) method.Call([]reflect.Value{})
//render template //render template
if AutoRender { if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil { if err := execController.Render(); err != nil {
panic(err) panic(err)
} }

View File

@ -1,5 +0,0 @@
appname = beeapi
httpport = 8080
runmode = dev
autorender = false
copyrequestbody = true

View File

@ -1,63 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package controllers
import (
"encoding/json"
"github.com/astaxie/beego"
"github.com/astaxie/beego/example/beeapi/models"
)
type ObjectController struct {
beego.Controller
}
func (o *ObjectController) Post() {
var ob models.Object
json.Unmarshal(o.Ctx.Input.RequestBody, &ob)
objectid := models.AddOne(ob)
o.Data["json"] = map[string]string{"ObjectId": objectid}
o.ServeJson()
}
func (o *ObjectController) Get() {
objectId := o.Ctx.Input.Params[":objectId"]
if objectId != "" {
ob, err := models.GetOne(objectId)
if err != nil {
o.Data["json"] = err
} else {
o.Data["json"] = ob
}
} else {
obs := models.GetAll()
o.Data["json"] = obs
}
o.ServeJson()
}
func (o *ObjectController) Put() {
objectId := o.Ctx.Input.Params[":objectId"]
var ob models.Object
json.Unmarshal(o.Ctx.Input.RequestBody, &ob)
err := models.Update(objectId, ob.Score)
if err != nil {
o.Data["json"] = err
} else {
o.Data["json"] = "update success!"
}
o.ServeJson()
}
func (o *ObjectController) Delete() {
objectId := o.Ctx.Input.Params[":objectId"]
models.Delete(objectId)
o.Data["json"] = "delete success!"
o.ServeJson()
}

View File

@ -1,30 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package main
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/example/beeapi/controllers"
)
// Objects
// URL HTTP Verb Functionality
// /object POST Creating Objects
// /object/<objectId> GET Retrieving Objects
// /object/<objectId> PUT Updating Objects
// /object GET Queries
// /object/<objectId> DELETE Deleting Objects
func main() {
beego.RESTRouter("/object", &controllers.ObjectController{})
beego.Run()
}

View File

@ -1,58 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package models
import (
"errors"
"strconv"
"time"
)
var (
Objects map[string]*Object
)
type Object struct {
ObjectId string
Score int64
PlayerName string
}
func init() {
Objects = make(map[string]*Object)
Objects["hjkhsbnmn123"] = &Object{"hjkhsbnmn123", 100, "astaxie"}
Objects["mjjkxsxsaa23"] = &Object{"mjjkxsxsaa23", 101, "someone"}
}
func AddOne(object Object) (ObjectId string) {
object.ObjectId = "astaxie" + strconv.FormatInt(time.Now().UnixNano(), 10)
Objects[object.ObjectId] = &object
return object.ObjectId
}
func GetOne(ObjectId string) (object *Object, err error) {
if v, ok := Objects[ObjectId]; ok {
return v, nil
}
return nil, errors.New("ObjectId Not Exist")
}
func GetAll() map[string]*Object {
return Objects
}
func Update(ObjectId string, Score int64) (err error) {
if v, ok := Objects[ObjectId]; ok {
v.Score = Score
return nil
}
return errors.New("ObjectId Not Exist")
}
func Delete(ObjectId string) {
delete(Objects, ObjectId)
}

View File

@ -1,3 +0,0 @@
appname = chat
httpport = 8080
runmode = dev

View File

@ -1,20 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors Unknwon
package controllers
import (
"github.com/astaxie/beego"
)
type MainController struct {
beego.Controller
}
func (m *MainController) Get() {
m.Data["host"] = m.Ctx.Request.Host
m.TplNames = "index.tpl"
}

View File

@ -1,181 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors Unknwon
package controllers
import (
"io/ioutil"
"math/rand"
"net/http"
"time"
"github.com/astaxie/beego"
"github.com/gorilla/websocket"
)
const (
// Time allowed to write a message to the client.
writeWait = 10 * time.Second
// Time allowed to read the next message from the client.
readWait = 60 * time.Second
// Send pings to client with this period. Must be less than readWait.
pingPeriod = (readWait * 9) / 10
// Maximum message size allowed from client.
maxMessageSize = 512
)
func init() {
rand.Seed(time.Now().UTC().UnixNano())
go h.run()
}
// connection is an middleman between the websocket connection and the hub.
type connection struct {
username string
// The websocket connection.
ws *websocket.Conn
// Buffered channel of outbound messages.
send chan []byte
}
// readPump pumps messages from the websocket connection to the hub.
func (c *connection) readPump() {
defer func() {
h.unregister <- c
c.ws.Close()
}()
c.ws.SetReadLimit(maxMessageSize)
c.ws.SetReadDeadline(time.Now().Add(readWait))
for {
op, r, err := c.ws.NextReader()
if err != nil {
break
}
switch op {
case websocket.PongMessage:
c.ws.SetReadDeadline(time.Now().Add(readWait))
case websocket.TextMessage:
message, err := ioutil.ReadAll(r)
if err != nil {
break
}
h.broadcast <- []byte(c.username + "_" + time.Now().Format("15:04:05") + ":" + string(message))
}
}
}
// write writes a message with the given opCode and payload.
func (c *connection) write(opCode int, payload []byte) error {
c.ws.SetWriteDeadline(time.Now().Add(writeWait))
return c.ws.WriteMessage(opCode, payload)
}
// writePump pumps messages from the hub to the websocket connection.
func (c *connection) writePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.ws.Close()
}()
for {
select {
case message, ok := <-c.send:
if !ok {
c.write(websocket.CloseMessage, []byte{})
return
}
if err := c.write(websocket.TextMessage, message); err != nil {
return
}
case <-ticker.C:
if err := c.write(websocket.PingMessage, []byte{}); err != nil {
return
}
}
}
}
type hub struct {
// Registered connections.
connections map[*connection]bool
// Inbound messages from the connections.
broadcast chan []byte
// Register requests from the connections.
register chan *connection
// Unregister requests from connections.
unregister chan *connection
}
var h = &hub{
broadcast: make(chan []byte, maxMessageSize),
register: make(chan *connection, 1),
unregister: make(chan *connection, 1),
connections: make(map[*connection]bool),
}
func (h *hub) run() {
for {
select {
case c := <-h.register:
h.connections[c] = true
case c := <-h.unregister:
delete(h.connections, c)
close(c.send)
case m := <-h.broadcast:
for c := range h.connections {
select {
case c.send <- m:
default:
close(c.send)
delete(h.connections, c)
}
}
}
}
}
type WSController struct {
beego.Controller
}
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
func (w *WSController) Get() {
ws, err := upgrader.Upgrade(w.Ctx.ResponseWriter, w.Ctx.Request, nil)
if _, ok := err.(websocket.HandshakeError); ok {
http.Error(w.Ctx.ResponseWriter, "Not a websocket handshake", 400)
return
} else if err != nil {
return
}
c := &connection{send: make(chan []byte, 256), ws: ws, username: randomString(10)}
h.register <- c
go c.writePump()
c.readPump()
}
func randomString(l int) string {
bytes := make([]byte, l)
for i := 0; i < l; i++ {
bytes[i] = byte(randInt(65, 90))
}
return string(bytes)
}
func randInt(min int, max int) int {
return min + rand.Intn(max-min)
}

View File

@ -1,17 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors Unknwon
package main
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/example/chat/controllers"
)
func main() {
beego.Router("/", &controllers.MainController{})
beego.Router("/ws", &controllers.WSController{})
beego.Run()
}

View File

@ -1,92 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>Chat Example</title>
<script src="//code.jquery.com/jquery-2.1.3.min.js"></script>
<script type="text/javascript">
$(function() {
var conn;
var msg = $("#msg");
var log = $("#log");
function appendLog(msg) {
var d = log[0]
var doScroll = d.scrollTop == d.scrollHeight - d.clientHeight;
msg.appendTo(log)
if (doScroll) {
d.scrollTop = d.scrollHeight - d.clientHeight;
}
}
$("#form").submit(function() {
if (!conn) {
return false;
}
if (!msg.val()) {
return false;
}
conn.send(msg.val());
msg.val("");
return false
});
if (window["WebSocket"]) {
conn = new WebSocket("ws://{{.host}}/ws");
conn.onclose = function(evt) {
appendLog($("<div><b>Connection closed.</b></div>"))
}
conn.onmessage = function(evt) {
appendLog($("<div/>").text(evt.data))
}
} else {
appendLog($("<div><b>Your browser does not support WebSockets.</b></div>"))
}
});
</script>
<style type="text/css">
html {
overflow: hidden;
}
body {
overflow: hidden;
padding: 0;
margin: 0;
width: 100%;
height: 100%;
background: gray;
}
#log {
background: white;
margin: 0;
padding: 0.5em 0.5em 0.5em 0.5em;
position: absolute;
top: 0.5em;
left: 0.5em;
right: 0.5em;
bottom: 3em;
overflow: auto;
}
#form {
padding: 0 0.5em 0 0.5em;
margin: 0;
position: absolute;
bottom: 1em;
left: 0px;
width: 100%;
overflow: hidden;
}
</style>
</head>
<body>
<div id="log"></div>
<form id="form">
<input type="submit" value="Send" />
<input type="text" id="msg" size="64"/>
</form>
</body>
</html>

View File

@ -32,14 +32,12 @@ type FilterRouter struct {
// ValidRouter checks if the current request is matched by this filter. // ValidRouter checks if the current request is matched by this filter.
// If the request is matched, the values of the URL parameters defined // If the request is matched, the values of the URL parameters defined
// by the filter pattern are also returned. // by the filter pattern are also returned.
func (f *FilterRouter) ValidRouter(url string) (bool, map[string]string) { func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool {
isok, params := f.tree.Match(url) isOk := f.tree.Match(url, ctx)
if isok == nil { if isOk != nil {
return false, nil if b, ok := isOk.(bool); ok {
} return b
if isok, ok := isok.(bool); ok { }
return isok, params
} else {
return false, nil
} }
return false
} }

View File

@ -20,10 +20,16 @@ import (
"testing" "testing"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
) )
func init() {
BeeLogger = logs.NewLogger(10000)
BeeLogger.SetLogger("console", "")
}
var FilterUser = func(ctx *context.Context) { var FilterUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"])) ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first")))
} }
func TestFilter(t *testing.T) { func TestFilter(t *testing.T) {

View File

@ -83,27 +83,27 @@ func (fd *FlashData) Store(c *Controller) {
c.Data["flash"] = fd.Data c.Data["flash"] = fd.Data
var flashValue string var flashValue string
for key, value := range fd.Data { for key, value := range fd.Data {
flashValue += "\x00" + key + "\x23" + FlashSeperator + "\x23" + value + "\x00" flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00"
} }
c.Ctx.SetCookie(FlashName, url.QueryEscape(flashValue), 0, "/") c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/")
} }
// ReadFromRequest parsed flash data from encoded values in cookie. // ReadFromRequest parsed flash data from encoded values in cookie.
func ReadFromRequest(c *Controller) *FlashData { func ReadFromRequest(c *Controller) *FlashData {
flash := NewFlash() flash := NewFlash()
if cookie, err := c.Ctx.Request.Cookie(FlashName); err == nil { if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil {
v, _ := url.QueryUnescape(cookie.Value) v, _ := url.QueryUnescape(cookie.Value)
vals := strings.Split(v, "\x00") vals := strings.Split(v, "\x00")
for _, v := range vals { for _, v := range vals {
if len(v) > 0 { if len(v) > 0 {
kv := strings.Split(v, "\x23"+FlashSeperator+"\x23") kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23")
if len(kv) == 2 { if len(kv) == 2 {
flash.Data[kv[0]] = kv[1] flash.Data[kv[0]] = kv[1]
} }
} }
} }
//read one time then delete it //read one time then delete it
c.Ctx.SetCookie(FlashName, "", -1, "/") c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/")
} }
c.Data["flash"] = flash.Data c.Data["flash"] = flash.Data
return flash return flash

View File

@ -30,7 +30,7 @@ func (t *TestFlashController) TestWriteFlash() {
flash.Notice("TestFlashString") flash.Notice("TestFlashString")
flash.Store(&t.Controller) flash.Store(&t.Controller)
// we choose to serve json because we don't want to load a template html file // we choose to serve json because we don't want to load a template html file
t.ServeJson(true) t.ServeJSON(true)
} }
func TestFlashHeader(t *testing.T) { func TestFlashHeader(t *testing.T) {

View File

@ -1,13 +1,28 @@
package grace package grace
import "net" import (
"errors"
"net"
)
type graceConn struct { type graceConn struct {
net.Conn net.Conn
server *graceServer server *Server
} }
func (c graceConn) Close() error { func (c graceConn) Close() (err error) {
defer func() {
if r := recover(); r != nil {
switch x := r.(type) {
case string:
err = errors.New(x)
case error:
err = x
default:
err = errors.New("Unknown panic")
}
}
}()
c.server.wg.Done() c.server.wg.Done()
return c.Conn.Close() return c.Conn.Close()
} }

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package grace use to hot reload
// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/ // Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/
// //
// Usage: // Usage:
@ -32,7 +33,7 @@
// mux := http.NewServeMux() // mux := http.NewServeMux()
// mux.HandleFunc("/hello", handler) // mux.HandleFunc("/hello", handler)
// //
// err := grace.ListenAndServe("localhost:8080", mux1) // err := grace.ListenAndServe("localhost:8080", mux)
// if err != nil { // if err != nil {
// log.Println(err) // log.Println(err)
// } // }
@ -52,46 +53,53 @@ import (
) )
const ( const (
PRE_SIGNAL = iota // PreSignal is the position to add filter before signal
POST_SIGNAL PreSignal = iota
// PostSignal is the position to add filter after signal
STATE_INIT PostSignal
STATE_RUNNING // StateInit represent the application inited
STATE_SHUTTING_DOWN StateInit
STATE_TERMINATE // StateRunning represent the application is running
StateRunning
// StateShuttingDown represent the application is shutting down
StateShuttingDown
// StateTerminate represent the application is killed
StateTerminate
) )
var ( var (
regLock *sync.Mutex regLock *sync.Mutex
runningServers map[string]*graceServer runningServers map[string]*Server
runningServersOrder []string runningServersOrder []string
socketPtrOffsetMap map[string]uint socketPtrOffsetMap map[string]uint
runningServersForked bool runningServersForked bool
DefaultReadTimeOut time.Duration // DefaultReadTimeOut is the HTTP read timeout
DefaultWriteTimeOut time.Duration DefaultReadTimeOut time.Duration
// DefaultWriteTimeOut is the HTTP Write timeout
DefaultWriteTimeOut time.Duration
// DefaultMaxHeaderBytes is the Max HTTP Herder size, default is 0, no limit
DefaultMaxHeaderBytes int DefaultMaxHeaderBytes int
DefaultTimeout time.Duration // DefaultTimeout is the shutdown server's timeout. default is 60s
DefaultTimeout = 60 * time.Second
isChild bool isChild bool
socketOrder string socketOrder string
once sync.Once
) )
func init() { func onceInit() {
regLock = &sync.Mutex{} regLock = &sync.Mutex{}
flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
runningServers = make(map[string]*graceServer) runningServers = make(map[string]*Server)
runningServersOrder = []string{} runningServersOrder = []string{}
socketPtrOffsetMap = make(map[string]uint) socketPtrOffsetMap = make(map[string]uint)
DefaultMaxHeaderBytes = 0
DefaultTimeout = 60 * time.Second
} }
// NewServer returns a new graceServer. // NewServer returns a new graceServer.
func NewServer(addr string, handler http.Handler) (srv *graceServer) { func NewServer(addr string, handler http.Handler) (srv *Server) {
once.Do(onceInit)
regLock.Lock() regLock.Lock()
defer regLock.Unlock() defer regLock.Unlock()
if !flag.Parsed() { if !flag.Parsed() {
@ -105,23 +113,23 @@ func NewServer(addr string, handler http.Handler) (srv *graceServer) {
socketPtrOffsetMap[addr] = uint(len(runningServersOrder)) socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
} }
srv = &graceServer{ srv = &Server{
wg: sync.WaitGroup{}, wg: sync.WaitGroup{},
sigChan: make(chan os.Signal), sigChan: make(chan os.Signal),
isChild: isChild, isChild: isChild,
SignalHooks: map[int]map[os.Signal][]func(){ SignalHooks: map[int]map[os.Signal][]func(){
PRE_SIGNAL: map[os.Signal][]func(){ PreSignal: map[os.Signal][]func(){
syscall.SIGHUP: []func(){}, syscall.SIGHUP: []func(){},
syscall.SIGINT: []func(){}, syscall.SIGINT: []func(){},
syscall.SIGTERM: []func(){}, syscall.SIGTERM: []func(){},
}, },
POST_SIGNAL: map[os.Signal][]func(){ PostSignal: map[os.Signal][]func(){
syscall.SIGHUP: []func(){}, syscall.SIGHUP: []func(){},
syscall.SIGINT: []func(){}, syscall.SIGINT: []func(){},
syscall.SIGTERM: []func(){}, syscall.SIGTERM: []func(){},
}, },
}, },
state: STATE_INIT, state: StateInit,
Network: "tcp", Network: "tcp",
} }
srv.Server = &http.Server{} srv.Server = &http.Server{}
@ -137,13 +145,13 @@ func NewServer(addr string, handler http.Handler) (srv *graceServer) {
return return
} }
// refer http.ListenAndServe // ListenAndServe refer http.ListenAndServe
func ListenAndServe(addr string, handler http.Handler) error { func ListenAndServe(addr string, handler http.Handler) error {
server := NewServer(addr, handler) server := NewServer(addr, handler)
return server.ListenAndServe() return server.ListenAndServe()
} }
// refer http.ListenAndServeTLS // ListenAndServeTLS refer http.ListenAndServeTLS
func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
server := NewServer(addr, handler) server := NewServer(addr, handler)
return server.ListenAndServeTLS(certFile, keyFile) return server.ListenAndServeTLS(certFile, keyFile)

View File

@ -11,10 +11,10 @@ type graceListener struct {
net.Listener net.Listener
stop chan error stop chan error
stopped bool stopped bool
server *graceServer server *Server
} }
func newGraceListener(l net.Listener, srv *graceServer) (el *graceListener) { func newGraceListener(l net.Listener, srv *Server) (el *graceListener) {
el = &graceListener{ el = &graceListener{
Listener: l, Listener: l,
stop: make(chan error), stop: make(chan error),
@ -46,17 +46,17 @@ func (gl *graceListener) Accept() (c net.Conn, err error) {
return return
} }
func (el *graceListener) Close() error { func (gl *graceListener) Close() error {
if el.stopped { if gl.stopped {
return syscall.EINVAL return syscall.EINVAL
} }
el.stop <- nil gl.stop <- nil
return <-el.stop return <-gl.stop
} }
func (el *graceListener) File() *os.File { func (gl *graceListener) File() *os.File {
// returns a dup(2) - FD_CLOEXEC flag *not* set // returns a dup(2) - FD_CLOEXEC flag *not* set
tl := el.Listener.(*net.TCPListener) tl := gl.Listener.(*net.TCPListener)
fl, _ := tl.File() fl, _ := tl.File()
return fl return fl
} }

View File

@ -15,7 +15,8 @@ import (
"time" "time"
) )
type graceServer struct { // Server embedded http.Server
type Server struct {
*http.Server *http.Server
GraceListener net.Listener GraceListener net.Listener
SignalHooks map[int]map[os.Signal][]func() SignalHooks map[int]map[os.Signal][]func()
@ -30,19 +31,19 @@ type graceServer struct {
// Serve accepts incoming connections on the Listener l, // Serve accepts incoming connections on the Listener l,
// creating a new service goroutine for each. // creating a new service goroutine for each.
// The service goroutines read requests and then call srv.Handler to reply to them. // The service goroutines read requests and then call srv.Handler to reply to them.
func (srv *graceServer) Serve() (err error) { func (srv *Server) Serve() (err error) {
srv.state = STATE_RUNNING srv.state = StateRunning
err = srv.Server.Serve(srv.GraceListener) err = srv.Server.Serve(srv.GraceListener)
log.Println(syscall.Getpid(), "Waiting for connections to finish...") log.Println(syscall.Getpid(), "Waiting for connections to finish...")
srv.wg.Wait() srv.wg.Wait()
srv.state = STATE_TERMINATE srv.state = StateTerminate
return return
} }
// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
// to handle requests on incoming connections. If srv.Addr is blank, ":http" is // to handle requests on incoming connections. If srv.Addr is blank, ":http" is
// used. // used.
func (srv *graceServer) ListenAndServe() (err error) { func (srv *Server) ListenAndServe() (err error) {
addr := srv.Addr addr := srv.Addr
if addr == "" { if addr == "" {
addr = ":http" addr = ":http"
@ -83,7 +84,7 @@ func (srv *graceServer) ListenAndServe() (err error) {
// CA's certificate. // CA's certificate.
// //
// If srv.Addr is blank, ":https" is used. // If srv.Addr is blank, ":https" is used.
func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error) { func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
addr := srv.Addr addr := srv.Addr
if addr == "" { if addr == "" {
addr = ":https" addr = ":https"
@ -131,9 +132,9 @@ func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error)
// getListener either opens a new socket to listen on, or takes the acceptor socket // getListener either opens a new socket to listen on, or takes the acceptor socket
// it got passed when restarted. // it got passed when restarted.
func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) { func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
if srv.isChild { if srv.isChild {
var ptrOffset uint = 0 var ptrOffset uint
if len(socketPtrOffsetMap) > 0 { if len(socketPtrOffsetMap) > 0 {
ptrOffset = socketPtrOffsetMap[laddr] ptrOffset = socketPtrOffsetMap[laddr]
log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr]) log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
@ -157,7 +158,7 @@ func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) {
// handleSignals listens for os Signals and calls any hooked in function that the // handleSignals listens for os Signals and calls any hooked in function that the
// user had registered with the signal. // user had registered with the signal.
func (srv *graceServer) handleSignals() { func (srv *Server) handleSignals() {
var sig os.Signal var sig os.Signal
signal.Notify( signal.Notify(
@ -170,7 +171,7 @@ func (srv *graceServer) handleSignals() {
pid := syscall.Getpid() pid := syscall.Getpid()
for { for {
sig = <-srv.sigChan sig = <-srv.sigChan
srv.signalHooks(PRE_SIGNAL, sig) srv.signalHooks(PreSignal, sig)
switch sig { switch sig {
case syscall.SIGHUP: case syscall.SIGHUP:
log.Println(pid, "Received SIGHUP. forking.") log.Println(pid, "Received SIGHUP. forking.")
@ -187,11 +188,11 @@ func (srv *graceServer) handleSignals() {
default: default:
log.Printf("Received %v: nothing i care about...\n", sig) log.Printf("Received %v: nothing i care about...\n", sig)
} }
srv.signalHooks(POST_SIGNAL, sig) srv.signalHooks(PostSignal, sig)
} }
} }
func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) { func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet { if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
return return
} }
@ -204,12 +205,12 @@ func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) {
// shutdown closes the listener so that no new connections are accepted. it also // shutdown closes the listener so that no new connections are accepted. it also
// starts a goroutine that will serverTimeout (stop all running requests) the server // starts a goroutine that will serverTimeout (stop all running requests) the server
// after DefaultTimeout. // after DefaultTimeout.
func (srv *graceServer) shutdown() { func (srv *Server) shutdown() {
if srv.state != STATE_RUNNING { if srv.state != StateRunning {
return return
} }
srv.state = STATE_SHUTTING_DOWN srv.state = StateShuttingDown
if DefaultTimeout >= 0 { if DefaultTimeout >= 0 {
go srv.serverTimeout(DefaultTimeout) go srv.serverTimeout(DefaultTimeout)
} }
@ -224,26 +225,26 @@ func (srv *graceServer) shutdown() {
// serverTimeout forces the server to shutdown in a given timeout - whether it // serverTimeout forces the server to shutdown in a given timeout - whether it
// finished outstanding requests or not. if Read/WriteTimeout are not set or the // finished outstanding requests or not. if Read/WriteTimeout are not set or the
// max header size is very big a connection could hang // max header size is very big a connection could hang
func (srv *graceServer) serverTimeout(d time.Duration) { func (srv *Server) serverTimeout(d time.Duration) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Println("WaitGroup at 0", r) log.Println("WaitGroup at 0", r)
} }
}() }()
if srv.state != STATE_SHUTTING_DOWN { if srv.state != StateShuttingDown {
return return
} }
time.Sleep(d) time.Sleep(d)
log.Println("[STOP - Hammer Time] Forcefully shutting down parent") log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
for { for {
if srv.state == STATE_TERMINATE { if srv.state == StateTerminate {
break break
} }
srv.wg.Done() srv.wg.Done()
} }
} }
func (srv *graceServer) fork() (err error) { func (srv *Server) fork() (err error) {
regLock.Lock() regLock.Lock()
defer regLock.Unlock() defer regLock.Unlock()
if runningServersForked { if runningServersForked {

95
hooks.go Normal file
View File

@ -0,0 +1,95 @@
package beego
import (
"encoding/json"
"mime"
"net/http"
"path/filepath"
"github.com/astaxie/beego/session"
)
//
func registerMime() error {
for k, v := range mimemaps {
mime.AddExtensionType(k, v)
}
return nil
}
// register default error http handlers, 404,401,403,500 and 503.
func registerDefaultErrorHandler() error {
m := map[string]func(http.ResponseWriter, *http.Request){
"401": unauthorized,
"402": paymentRequired,
"403": forbidden,
"404": notFound,
"405": methodNotAllowed,
"500": internalServerError,
"501": notImplemented,
"502": badGateway,
"503": serviceUnavailable,
"504": gatewayTimeout,
}
for e, h := range m {
if _, ok := ErrorMaps[e]; !ok {
ErrorHandler(e, h)
}
}
return nil
}
func registerSession() error {
if BConfig.WebConfig.Session.SessionOn {
var err error
sessionConfig := AppConfig.String("sessionConfig")
if sessionConfig == "" {
conf := map[string]interface{}{
"cookieName": BConfig.WebConfig.Session.SessionName,
"gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime,
"providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig),
"secure": BConfig.Listen.EnableHTTPS,
"enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie,
"domain": BConfig.WebConfig.Session.SessionDomain,
"cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime,
}
confBytes, err := json.Marshal(conf)
if err != nil {
return err
}
sessionConfig = string(confBytes)
}
if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, sessionConfig); err != nil {
return err
}
go GlobalSessions.GC()
}
return nil
}
func registerTemplate() error {
if BConfig.WebConfig.AutoRender {
if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil {
if BConfig.RunMode == DEV {
Warn(err)
}
return err
}
}
return nil
}
func registerDocs() error {
if BConfig.WebConfig.EnableDocs {
Get("/docs", serverDocs)
Get("/docs/*", serverDocs)
}
return nil
}
func registerAdmin() error {
if BConfig.Listen.EnableAdmin {
go beeAdminApp.Run()
}
return nil
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package httplib is used as http.Client
// Usage: // Usage:
// //
// import "github.com/astaxie/beego/httplib" // import "github.com/astaxie/beego/httplib"
@ -51,7 +52,7 @@ import (
"time" "time"
) )
var defaultSetting = BeegoHttpSettings{ var defaultSetting = BeegoHTTPSettings{
UserAgent: "beegoServer", UserAgent: "beegoServer",
ConnectTimeout: 60 * time.Second, ConnectTimeout: 60 * time.Second,
ReadWriteTimeout: 60 * time.Second, ReadWriteTimeout: 60 * time.Second,
@ -69,25 +70,19 @@ func createDefaultCookie() {
defaultCookieJar, _ = cookiejar.New(nil) defaultCookieJar, _ = cookiejar.New(nil)
} }
// Overwrite default settings // SetDefaultSetting Overwrite default settings
func SetDefaultSetting(setting BeegoHttpSettings) { func SetDefaultSetting(setting BeegoHTTPSettings) {
settingMutex.Lock() settingMutex.Lock()
defer settingMutex.Unlock() defer settingMutex.Unlock()
defaultSetting = setting defaultSetting = setting
if defaultSetting.ConnectTimeout == 0 {
defaultSetting.ConnectTimeout = 60 * time.Second
}
if defaultSetting.ReadWriteTimeout == 0 {
defaultSetting.ReadWriteTimeout = 60 * time.Second
}
} }
// return *BeegoHttpRequest with specific method // NewBeegoRequest return *BeegoHttpRequest with specific method
func NewBeegoRequest(rawurl, method string) *BeegoHttpRequest { func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest {
var resp http.Response var resp http.Response
u, err := url.Parse(rawurl) u, err := url.Parse(rawurl)
if err != nil { if err != nil {
log.Fatal(err) log.Println("Httplib:", err)
} }
req := http.Request{ req := http.Request{
URL: u, URL: u,
@ -97,10 +92,10 @@ func NewBeegoRequest(rawurl, method string) *BeegoHttpRequest {
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
} }
return &BeegoHttpRequest{ return &BeegoHTTPRequest{
url: rawurl, url: rawurl,
req: &req, req: &req,
params: map[string]string{}, params: map[string][]string{},
files: map[string]string{}, files: map[string]string{},
setting: defaultSetting, setting: defaultSetting,
resp: &resp, resp: &resp,
@ -108,37 +103,37 @@ func NewBeegoRequest(rawurl, method string) *BeegoHttpRequest {
} }
// Get returns *BeegoHttpRequest with GET method. // Get returns *BeegoHttpRequest with GET method.
func Get(url string) *BeegoHttpRequest { func Get(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "GET") return NewBeegoRequest(url, "GET")
} }
// Post returns *BeegoHttpRequest with POST method. // Post returns *BeegoHttpRequest with POST method.
func Post(url string) *BeegoHttpRequest { func Post(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "POST") return NewBeegoRequest(url, "POST")
} }
// Put returns *BeegoHttpRequest with PUT method. // Put returns *BeegoHttpRequest with PUT method.
func Put(url string) *BeegoHttpRequest { func Put(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "PUT") return NewBeegoRequest(url, "PUT")
} }
// Delete returns *BeegoHttpRequest DELETE method. // Delete returns *BeegoHttpRequest DELETE method.
func Delete(url string) *BeegoHttpRequest { func Delete(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "DELETE") return NewBeegoRequest(url, "DELETE")
} }
// Head returns *BeegoHttpRequest with HEAD method. // Head returns *BeegoHttpRequest with HEAD method.
func Head(url string) *BeegoHttpRequest { func Head(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "HEAD") return NewBeegoRequest(url, "HEAD")
} }
// BeegoHttpSettings // BeegoHTTPSettings is the http.Client setting
type BeegoHttpSettings struct { type BeegoHTTPSettings struct {
ShowDebug bool ShowDebug bool
UserAgent string UserAgent string
ConnectTimeout time.Duration ConnectTimeout time.Duration
ReadWriteTimeout time.Duration ReadWriteTimeout time.Duration
TlsClientConfig *tls.Config TLSClientConfig *tls.Config
Proxy func(*http.Request) (*url.URL, error) Proxy func(*http.Request) (*url.URL, error)
Transport http.RoundTripper Transport http.RoundTripper
EnableCookie bool EnableCookie bool
@ -146,92 +141,92 @@ type BeegoHttpSettings struct {
DumpBody bool DumpBody bool
} }
// BeegoHttpRequest provides more useful methods for requesting one url than http.Request. // BeegoHTTPRequest provides more useful methods for requesting one url than http.Request.
type BeegoHttpRequest struct { type BeegoHTTPRequest struct {
url string url string
req *http.Request req *http.Request
params map[string]string params map[string][]string
files map[string]string files map[string]string
setting BeegoHttpSettings setting BeegoHTTPSettings
resp *http.Response resp *http.Response
body []byte body []byte
dump []byte dump []byte
} }
// get request // GetRequest return the request object
func (b *BeegoHttpRequest) GetRequest() *http.Request { func (b *BeegoHTTPRequest) GetRequest() *http.Request {
return b.req return b.req
} }
// Change request settings // Setting Change request settings
func (b *BeegoHttpRequest) Setting(setting BeegoHttpSettings) *BeegoHttpRequest { func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest {
b.setting = setting b.setting = setting
return b return b
} }
// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. // SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password.
func (b *BeegoHttpRequest) SetBasicAuth(username, password string) *BeegoHttpRequest { func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest {
b.req.SetBasicAuth(username, password) b.req.SetBasicAuth(username, password)
return b return b
} }
// SetEnableCookie sets enable/disable cookiejar // SetEnableCookie sets enable/disable cookiejar
func (b *BeegoHttpRequest) SetEnableCookie(enable bool) *BeegoHttpRequest { func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest {
b.setting.EnableCookie = enable b.setting.EnableCookie = enable
return b return b
} }
// SetUserAgent sets User-Agent header field // SetUserAgent sets User-Agent header field
func (b *BeegoHttpRequest) SetUserAgent(useragent string) *BeegoHttpRequest { func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest {
b.setting.UserAgent = useragent b.setting.UserAgent = useragent
return b return b
} }
// Debug sets show debug or not when executing request. // Debug sets show debug or not when executing request.
func (b *BeegoHttpRequest) Debug(isdebug bool) *BeegoHttpRequest { func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
b.setting.ShowDebug = isdebug b.setting.ShowDebug = isdebug
return b return b
} }
// Dump Body. // DumpBody setting whether need to Dump the Body.
func (b *BeegoHttpRequest) DumpBody(isdump bool) *BeegoHttpRequest { func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
b.setting.DumpBody = isdump b.setting.DumpBody = isdump
return b return b
} }
// return the DumpRequest // DumpRequest return the DumpRequest
func (b *BeegoHttpRequest) DumpRequest() []byte { func (b *BeegoHTTPRequest) DumpRequest() []byte {
return b.dump return b.dump
} }
// SetTimeout sets connect time out and read-write time out for BeegoRequest. // SetTimeout sets connect time out and read-write time out for BeegoRequest.
func (b *BeegoHttpRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHttpRequest { func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest {
b.setting.ConnectTimeout = connectTimeout b.setting.ConnectTimeout = connectTimeout
b.setting.ReadWriteTimeout = readWriteTimeout b.setting.ReadWriteTimeout = readWriteTimeout
return b return b
} }
// SetTLSClientConfig sets tls connection configurations if visiting https url. // SetTLSClientConfig sets tls connection configurations if visiting https url.
func (b *BeegoHttpRequest) SetTLSClientConfig(config *tls.Config) *BeegoHttpRequest { func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest {
b.setting.TlsClientConfig = config b.setting.TLSClientConfig = config
return b return b
} }
// Header add header item string in request. // Header add header item string in request.
func (b *BeegoHttpRequest) Header(key, value string) *BeegoHttpRequest { func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest {
b.req.Header.Set(key, value) b.req.Header.Set(key, value)
return b return b
} }
// Set HOST // SetHost set the request host
func (b *BeegoHttpRequest) SetHost(host string) *BeegoHttpRequest { func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest {
b.req.Host = host b.req.Host = host
return b return b
} }
// Set the protocol version for incoming requests. // SetProtocolVersion Set the protocol version for incoming requests.
// Client requests always use HTTP/1.1. // Client requests always use HTTP/1.1.
func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest { func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
if len(vers) == 0 { if len(vers) == 0 {
vers = "HTTP/1.1" vers = "HTTP/1.1"
} }
@ -247,44 +242,49 @@ func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest {
} }
// SetCookie add cookie into request. // SetCookie add cookie into request.
func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest { func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest {
b.req.Header.Add("Cookie", cookie.String()) b.req.Header.Add("Cookie", cookie.String())
return b return b
} }
// Set transport to // SetTransport set the setting transport
func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest { func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest {
b.setting.Transport = transport b.setting.Transport = transport
return b return b
} }
// Set http proxy // SetProxy set the http proxy
// example: // example:
// //
// func(req *http.Request) (*url.URL, error) { // func(req *http.Request) (*url.URL, error) {
// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") // u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
// return u, nil // return u, nil
// } // }
func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest { func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest {
b.setting.Proxy = proxy b.setting.Proxy = proxy
return b return b
} }
// Param adds query param in to request. // Param adds query param in to request.
// params build query string as ?key1=value1&key2=value2... // params build query string as ?key1=value1&key2=value2...
func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest { func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest {
b.params[key] = value if param, ok := b.params[key]; ok {
b.params[key] = append(param, value)
} else {
b.params[key] = []string{value}
}
return b return b
} }
func (b *BeegoHttpRequest) PostFile(formname, filename string) *BeegoHttpRequest { // PostFile add a post file to the request
func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest {
b.files[formname] = filename b.files[formname] = filename
return b return b
} }
// Body adds request raw body. // Body adds request raw body.
// it supports string and []byte. // it supports string and []byte.
func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest { func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
switch t := data.(type) { switch t := data.(type) {
case string: case string:
bf := bytes.NewBufferString(t) bf := bytes.NewBufferString(t)
@ -298,8 +298,8 @@ func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest {
return b return b
} }
// JsonBody adds request raw body encoding by JSON. // JSONBody adds request raw body encoding by JSON.
func (b *BeegoHttpRequest) JsonBody(obj interface{}) (*BeegoHttpRequest, error) { func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil { if b.req.Body == nil && obj != nil {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
enc := json.NewEncoder(buf) enc := json.NewEncoder(buf)
@ -313,7 +313,7 @@ func (b *BeegoHttpRequest) JsonBody(obj interface{}) (*BeegoHttpRequest, error)
return b, nil return b, nil
} }
func (b *BeegoHttpRequest) buildUrl(paramBody string) { func (b *BeegoHTTPRequest) buildURL(paramBody string) {
// build GET url with query string // build GET url with query string
if b.req.Method == "GET" && len(paramBody) > 0 { if b.req.Method == "GET" && len(paramBody) > 0 {
if strings.Index(b.url, "?") != -1 { if strings.Index(b.url, "?") != -1 {
@ -334,21 +334,23 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) {
for formname, filename := range b.files { for formname, filename := range b.files {
fileWriter, err := bodyWriter.CreateFormFile(formname, filename) fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
if err != nil { if err != nil {
log.Fatal(err) log.Println("Httplib:", err)
} }
fh, err := os.Open(filename) fh, err := os.Open(filename)
if err != nil { if err != nil {
log.Fatal(err) log.Println("Httplib:", err)
} }
//iocopy //iocopy
_, err = io.Copy(fileWriter, fh) _, err = io.Copy(fileWriter, fh)
fh.Close() fh.Close()
if err != nil { if err != nil {
log.Fatal(err) log.Println("Httplib:", err)
} }
} }
for k, v := range b.params { for k, v := range b.params {
bodyWriter.WriteField(k, v) for _, vv := range v {
bodyWriter.WriteField(k, vv)
}
} }
bodyWriter.Close() bodyWriter.Close()
pw.Close() pw.Close()
@ -366,11 +368,11 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) {
} }
} }
func (b *BeegoHttpRequest) getResponse() (*http.Response, error) { func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
if b.resp.StatusCode != 0 { if b.resp.StatusCode != 0 {
return b.resp, nil return b.resp, nil
} }
resp, err := b.SendOut() resp, err := b.DoRequest()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -378,21 +380,24 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
return resp, nil return resp, nil
} }
func (b *BeegoHttpRequest) SendOut() (*http.Response, error) { // DoRequest will do the client.Do
func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
var paramBody string var paramBody string
if len(b.params) > 0 { if len(b.params) > 0 {
var buf bytes.Buffer var buf bytes.Buffer
for k, v := range b.params { for k, v := range b.params {
buf.WriteString(url.QueryEscape(k)) for _, vv := range v {
buf.WriteByte('=') buf.WriteString(url.QueryEscape(k))
buf.WriteString(url.QueryEscape(v)) buf.WriteByte('=')
buf.WriteByte('&') buf.WriteString(url.QueryEscape(vv))
buf.WriteByte('&')
}
} }
paramBody = buf.String() paramBody = buf.String()
paramBody = paramBody[0 : len(paramBody)-1] paramBody = paramBody[0 : len(paramBody)-1]
} }
b.buildUrl(paramBody) b.buildURL(paramBody)
url, err := url.Parse(b.url) url, err := url.Parse(b.url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -405,7 +410,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) {
if trans == nil { if trans == nil {
// create default transport // create default transport
trans = &http.Transport{ trans = &http.Transport{
TLSClientConfig: b.setting.TlsClientConfig, TLSClientConfig: b.setting.TLSClientConfig,
Proxy: b.setting.Proxy, Proxy: b.setting.Proxy,
Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
} }
@ -413,7 +418,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) {
// if b.transport is *http.Transport then set the settings. // if b.transport is *http.Transport then set the settings.
if t, ok := trans.(*http.Transport); ok { if t, ok := trans.(*http.Transport); ok {
if t.TLSClientConfig == nil { if t.TLSClientConfig == nil {
t.TLSClientConfig = b.setting.TlsClientConfig t.TLSClientConfig = b.setting.TLSClientConfig
} }
if t.Proxy == nil { if t.Proxy == nil {
t.Proxy = b.setting.Proxy t.Proxy = b.setting.Proxy
@ -424,7 +429,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) {
} }
} }
var jar http.CookieJar = nil var jar http.CookieJar
if b.setting.EnableCookie { if b.setting.EnableCookie {
if defaultCookieJar == nil { if defaultCookieJar == nil {
createDefaultCookie() createDefaultCookie()
@ -453,7 +458,7 @@ func (b *BeegoHttpRequest) SendOut() (*http.Response, error) {
// String returns the body string in response. // String returns the body string in response.
// it calls Response inner. // it calls Response inner.
func (b *BeegoHttpRequest) String() (string, error) { func (b *BeegoHTTPRequest) String() (string, error) {
data, err := b.Bytes() data, err := b.Bytes()
if err != nil { if err != nil {
return "", err return "", err
@ -464,7 +469,7 @@ func (b *BeegoHttpRequest) String() (string, error) {
// Bytes returns the body []byte in response. // Bytes returns the body []byte in response.
// it calls Response inner. // it calls Response inner.
func (b *BeegoHttpRequest) Bytes() ([]byte, error) { func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
if b.body != nil { if b.body != nil {
return b.body, nil return b.body, nil
} }
@ -490,7 +495,7 @@ func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
// ToFile saves the body data in response to one file. // ToFile saves the body data in response to one file.
// it calls Response inner. // it calls Response inner.
func (b *BeegoHttpRequest) ToFile(filename string) error { func (b *BeegoHTTPRequest) ToFile(filename string) error {
f, err := os.Create(filename) f, err := os.Create(filename)
if err != nil { if err != nil {
return err return err
@ -509,9 +514,9 @@ func (b *BeegoHttpRequest) ToFile(filename string) error {
return err return err
} }
// ToJson returns the map that marshals from the body bytes as json in response . // ToJSON returns the map that marshals from the body bytes as json in response .
// it calls Response inner. // it calls Response inner.
func (b *BeegoHttpRequest) ToJson(v interface{}) error { func (b *BeegoHTTPRequest) ToJSON(v interface{}) error {
data, err := b.Bytes() data, err := b.Bytes()
if err != nil { if err != nil {
return err return err
@ -519,9 +524,9 @@ func (b *BeegoHttpRequest) ToJson(v interface{}) error {
return json.Unmarshal(data, v) return json.Unmarshal(data, v)
} }
// ToXml returns the map that marshals from the body bytes as xml in response . // ToXML returns the map that marshals from the body bytes as xml in response .
// it calls Response inner. // it calls Response inner.
func (b *BeegoHttpRequest) ToXml(v interface{}) error { func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
data, err := b.Bytes() data, err := b.Bytes()
if err != nil { if err != nil {
return err return err
@ -530,7 +535,7 @@ func (b *BeegoHttpRequest) ToXml(v interface{}) error {
} }
// Response executes request client gets response mannually. // Response executes request client gets response mannually.
func (b *BeegoHttpRequest) Response() (*http.Response, error) { func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
return b.getResponse() return b.getResponse()
} }

View File

@ -19,6 +19,7 @@ import (
"os" "os"
"strings" "strings"
"testing" "testing"
"time"
) )
func TestResponse(t *testing.T) { func TestResponse(t *testing.T) {
@ -149,10 +150,11 @@ func TestWithUserAgent(t *testing.T) {
func TestWithSetting(t *testing.T) { func TestWithSetting(t *testing.T) {
v := "beego" v := "beego"
var setting BeegoHttpSettings var setting BeegoHTTPSettings
setting.EnableCookie = true setting.EnableCookie = true
setting.UserAgent = v setting.UserAgent = v
setting.Transport = nil setting.Transport = nil
setting.ReadWriteTimeout = 5 * time.Second
SetDefaultSetting(setting) SetDefaultSetting(setting)
str, err := Get("http://httpbin.org/get").String() str, err := Get("http://httpbin.org/get").String()
@ -176,11 +178,11 @@ func TestToJson(t *testing.T) {
t.Log(resp) t.Log(resp)
// httpbin will return http remote addr // httpbin will return http remote addr
type Ip struct { type IP struct {
Origin string `json:"origin"` Origin string `json:"origin"`
} }
var ip Ip var ip IP
err = req.ToJson(&ip) err = req.ToJSON(&ip)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

19
log.go
View File

@ -32,20 +32,20 @@ const (
LevelDebug LevelDebug
) )
// SetLogLevel sets the global log level used by the simple // BeeLogger references the used application logger.
// logger. var BeeLogger = logs.NewLogger(100)
// SetLevel sets the global log level used by the simple logger.
func SetLevel(l int) { func SetLevel(l int) {
BeeLogger.SetLevel(l) BeeLogger.SetLevel(l)
} }
// SetLogFuncCall set the CallDepth, default is 3
func SetLogFuncCall(b bool) { func SetLogFuncCall(b bool) {
BeeLogger.EnableFuncCallDepth(b) BeeLogger.EnableFuncCallDepth(b)
BeeLogger.SetLogFuncCallDepth(3) BeeLogger.SetLogFuncCallDepth(3)
} }
// logger references the used application logger.
var BeeLogger *logs.BeeLogger
// SetLogger sets a new logger. // SetLogger sets a new logger.
func SetLogger(adaptername string, config string) error { func SetLogger(adaptername string, config string) error {
err := BeeLogger.SetLogger(adaptername, config) err := BeeLogger.SetLogger(adaptername, config)
@ -55,10 +55,12 @@ func SetLogger(adaptername string, config string) error {
return nil return nil
} }
// Emergency logs a message at emergency level.
func Emergency(v ...interface{}) { func Emergency(v ...interface{}) {
BeeLogger.Emergency(generateFmtStr(len(v)), v...) BeeLogger.Emergency(generateFmtStr(len(v)), v...)
} }
// Alert logs a message at alert level.
func Alert(v ...interface{}) { func Alert(v ...interface{}) {
BeeLogger.Alert(generateFmtStr(len(v)), v...) BeeLogger.Alert(generateFmtStr(len(v)), v...)
} }
@ -78,21 +80,22 @@ func Warning(v ...interface{}) {
BeeLogger.Warning(generateFmtStr(len(v)), v...) BeeLogger.Warning(generateFmtStr(len(v)), v...)
} }
// compatibility alias for Warning() // Warn compatibility alias for Warning()
func Warn(v ...interface{}) { func Warn(v ...interface{}) {
BeeLogger.Warn(generateFmtStr(len(v)), v...) BeeLogger.Warn(generateFmtStr(len(v)), v...)
} }
// Notice logs a message at notice level.
func Notice(v ...interface{}) { func Notice(v ...interface{}) {
BeeLogger.Notice(generateFmtStr(len(v)), v...) BeeLogger.Notice(generateFmtStr(len(v)), v...)
} }
// Info logs a message at info level. // Informational logs a message at info level.
func Informational(v ...interface{}) { func Informational(v ...interface{}) {
BeeLogger.Informational(generateFmtStr(len(v)), v...) BeeLogger.Informational(generateFmtStr(len(v)), v...)
} }
// compatibility alias for Warning() // Info compatibility alias for Warning()
func Info(v ...interface{}) { func Info(v ...interface{}) {
BeeLogger.Info(generateFmtStr(len(v)), v...) BeeLogger.Info(generateFmtStr(len(v)), v...)
} }

View File

@ -21,9 +21,9 @@ import (
"net" "net"
) )
// ConnWriter implements LoggerInterface. // connWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection. // it writes messages in keep-live tcp connection.
type ConnWriter struct { type connWriter struct {
lg *log.Logger lg *log.Logger
innerWriter io.WriteCloser innerWriter io.WriteCloser
ReconnectOnMsg bool `json:"reconnectOnMsg"` ReconnectOnMsg bool `json:"reconnectOnMsg"`
@ -33,22 +33,22 @@ type ConnWriter struct {
Level int `json:"level"` Level int `json:"level"`
} }
// create new ConnWrite returning as LoggerInterface. // NewConn create new ConnWrite returning as LoggerInterface.
func NewConn() LoggerInterface { func NewConn() Logger {
conn := new(ConnWriter) conn := new(connWriter)
conn.Level = LevelTrace conn.Level = LevelTrace
return conn return conn
} }
// init connection writer with json config. // Init init connection writer with json config.
// json config only need key "level". // json config only need key "level".
func (c *ConnWriter) Init(jsonconfig string) error { func (c *connWriter) Init(jsonconfig string) error {
return json.Unmarshal([]byte(jsonconfig), c) return json.Unmarshal([]byte(jsonconfig), c)
} }
// write message in connection. // WriteMsg write message in connection.
// if connection is down, try to re-connect. // if connection is down, try to re-connect.
func (c *ConnWriter) WriteMsg(msg string, level int) error { func (c *connWriter) WriteMsg(msg string, level int) error {
if level > c.Level { if level > c.Level {
return nil return nil
} }
@ -66,19 +66,19 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
return nil return nil
} }
// implementing method. empty. // Flush implementing method. empty.
func (c *ConnWriter) Flush() { func (c *connWriter) Flush() {
} }
// destroy connection writer and close tcp listener. // Destroy destroy connection writer and close tcp listener.
func (c *ConnWriter) Destroy() { func (c *connWriter) Destroy() {
if c.innerWriter != nil { if c.innerWriter != nil {
c.innerWriter.Close() c.innerWriter.Close()
} }
} }
func (c *ConnWriter) connect() error { func (c *connWriter) connect() error {
if c.innerWriter != nil { if c.innerWriter != nil {
c.innerWriter.Close() c.innerWriter.Close()
c.innerWriter = nil c.innerWriter = nil
@ -98,7 +98,7 @@ func (c *ConnWriter) connect() error {
return nil return nil
} }
func (c *ConnWriter) neddedConnectOnMsg() bool { func (c *connWriter) neddedConnectOnMsg() bool {
if c.Reconnect { if c.Reconnect {
c.Reconnect = false c.Reconnect = false
return true return true

View File

@ -21,9 +21,11 @@ import (
"runtime" "runtime"
) )
type Brush func(string) string // brush is a color join function
type brush func(string) string
func NewBrush(color string) Brush { // newBrush return a fix color Brush
func newBrush(color string) brush {
pre := "\033[" pre := "\033["
reset := "\033[0m" reset := "\033[0m"
return func(text string) string { return func(text string) string {
@ -31,43 +33,43 @@ func NewBrush(color string) Brush {
} }
} }
var colors = []Brush{ var colors = []brush{
NewBrush("1;37"), // Emergency white newBrush("1;37"), // Emergency white
NewBrush("1;36"), // Alert cyan newBrush("1;36"), // Alert cyan
NewBrush("1;35"), // Critical magenta newBrush("1;35"), // Critical magenta
NewBrush("1;31"), // Error red newBrush("1;31"), // Error red
NewBrush("1;33"), // Warning yellow newBrush("1;33"), // Warning yellow
NewBrush("1;32"), // Notice green newBrush("1;32"), // Notice green
NewBrush("1;34"), // Informational blue newBrush("1;34"), // Informational blue
NewBrush("1;34"), // Debug blue newBrush("1;34"), // Debug blue
} }
// ConsoleWriter implements LoggerInterface and writes messages to terminal. // consoleWriter implements LoggerInterface and writes messages to terminal.
type ConsoleWriter struct { type consoleWriter struct {
lg *log.Logger lg *log.Logger
Level int `json:"level"` Level int `json:"level"`
} }
// create ConsoleWriter returning as LoggerInterface. // NewConsole create ConsoleWriter returning as LoggerInterface.
func NewConsole() LoggerInterface { func NewConsole() Logger {
cw := &ConsoleWriter{ cw := &consoleWriter{
lg: log.New(os.Stdout, "", log.Ldate|log.Ltime), lg: log.New(os.Stdout, "", log.Ldate|log.Ltime),
Level: LevelDebug, Level: LevelDebug,
} }
return cw return cw
} }
// init console logger. // Init init console logger.
// jsonconfig like '{"level":LevelTrace}'. // jsonconfig like '{"level":LevelTrace}'.
func (c *ConsoleWriter) Init(jsonconfig string) error { func (c *consoleWriter) Init(jsonconfig string) error {
if len(jsonconfig) == 0 { if len(jsonconfig) == 0 {
return nil return nil
} }
return json.Unmarshal([]byte(jsonconfig), c) return json.Unmarshal([]byte(jsonconfig), c)
} }
// write message in console. // WriteMsg write message in console.
func (c *ConsoleWriter) WriteMsg(msg string, level int) error { func (c *consoleWriter) WriteMsg(msg string, level int) error {
if level > c.Level { if level > c.Level {
return nil return nil
} }
@ -80,13 +82,13 @@ func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
return nil return nil
} }
// implementing method. empty. // Destroy implementing method. empty.
func (c *ConsoleWriter) Destroy() { func (c *consoleWriter) Destroy() {
} }
// implementing method. empty. // Flush implementing method. empty.
func (c *ConsoleWriter) Flush() { func (c *consoleWriter) Flush() {
} }

View File

@ -43,11 +43,3 @@ func TestConsole(t *testing.T) {
testConsoleCalls(log2) testConsoleCalls(log2)
} }
func BenchmarkConsole(b *testing.B) {
log := NewLogger(10000)
log.EnableFuncCallDepth(true)
log.SetLogger("console", "")
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
}

View File

@ -12,7 +12,8 @@ import (
"github.com/belogik/goes" "github.com/belogik/goes"
) )
func NewES() logs.LoggerInterface { // NewES return a LoggerInterface
func NewES() logs.Logger {
cw := &esLogger{ cw := &esLogger{
Level: logs.LevelDebug, Level: logs.LevelDebug,
} }
@ -46,6 +47,7 @@ func (el *esLogger) Init(jsonconfig string) error {
return nil return nil
} }
// WriteMsg will write the msg and level into es
func (el *esLogger) WriteMsg(msg string, level int) error { func (el *esLogger) WriteMsg(msg string, level int) error {
if level > el.Level { if level > el.Level {
return nil return nil
@ -63,10 +65,12 @@ func (el *esLogger) WriteMsg(msg string, level int) error {
return err return err
} }
// Destroy is a empty method
func (el *esLogger) Destroy() { func (el *esLogger) Destroy() {
} }
// Flush is a empty method
func (el *esLogger) Flush() { func (el *esLogger) Flush() {
} }

View File

@ -20,7 +20,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -28,84 +27,62 @@ import (
"time" "time"
) )
// FileLogWriter implements LoggerInterface. // fileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency. // It writes messages by lines limit, file size limit, or time frequency.
type FileLogWriter struct { type fileLogWriter struct {
*log.Logger sync.Mutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize
mw *MuxWriter
// The opened file // The opened file
Filename string `json:"filename"` Filename string `json:"filename"`
fileWriter *os.File
Maxlines int `json:"maxlines"` // Rotate at line
maxlines_curlines int MaxLines int `json:"maxlines"`
maxLinesCurLines int
// Rotate at size // Rotate at size
Maxsize int `json:"maxsize"` MaxSize int `json:"maxsize"`
maxsize_cursize int maxSizeCurSize int
// Rotate daily // Rotate daily
Daily bool `json:"daily"` Daily bool `json:"daily"`
Maxdays int64 `json:"maxdays"` MaxDays int64 `json:"maxdays"`
daily_opendate int dailyOpenDate int
Rotate bool `json:"rotate"` Rotate bool `json:"rotate"`
startLock sync.Mutex // Only one log can write to the file
Level int `json:"level"` Level int `json:"level"`
Perm os.FileMode `json:"perm"`
} }
// an *os.File writer with locker. // NewFileWriter create a FileLogWriter returning as LoggerInterface.
type MuxWriter struct { func newFileWriter() Logger {
sync.Mutex w := &fileLogWriter{
fd *os.File
}
// write to os.File.
func (l *MuxWriter) Write(b []byte) (int, error) {
l.Lock()
defer l.Unlock()
return l.fd.Write(b)
}
// set os.File in writer.
func (l *MuxWriter) SetFd(fd *os.File) {
if l.fd != nil {
l.fd.Close()
}
l.fd = fd
}
// create a FileLogWriter returning as LoggerInterface.
func NewFileWriter() LoggerInterface {
w := &FileLogWriter{
Filename: "", Filename: "",
Maxlines: 1000000, MaxLines: 1000000,
Maxsize: 1 << 28, //256 MB MaxSize: 1 << 28, //256 MB
Daily: true, Daily: true,
Maxdays: 7, MaxDays: 7,
Rotate: true, Rotate: true,
Level: LevelTrace, Level: LevelTrace,
Perm: 0660,
} }
// use MuxWriter instead direct use os.File for lock write when rotate
w.mw = new(MuxWriter)
// set MuxWriter as Logger's io.Writer
w.Logger = log.New(w.mw, "", log.Ldate|log.Ltime)
return w return w
} }
// Init file logger with json config. // Init file logger with json config.
// jsonconfig like: // jsonConfig like:
// { // {
// "filename":"logs/beego.log", // "filename":"logs/beego.log",
// "maxlines":10000, // "maxLines":10000,
// "maxsize":1<<30, // "maxsize":1<<30,
// "daily":true, // "daily":true,
// "maxdays":15, // "maxDays":15,
// "rotate":true // "rotate":true,
// "perm":0600
// } // }
func (w *FileLogWriter) Init(jsonconfig string) error { func (w *fileLogWriter) Init(jsonConfig string) error {
err := json.Unmarshal([]byte(jsonconfig), w) err := json.Unmarshal([]byte(jsonConfig), w)
if err != nil { if err != nil {
return err return err
} }
@ -117,67 +94,120 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
} }
// start file logger. create log file and set to locker-inside file writer. // start file logger. create log file and set to locker-inside file writer.
func (w *FileLogWriter) startLogger() error { func (w *fileLogWriter) startLogger() error {
fd, err := w.createLogFile() file, err := w.createLogFile()
if err != nil { if err != nil {
return err return err
} }
w.mw.SetFd(fd) if w.fileWriter != nil {
w.fileWriter.Close()
}
w.fileWriter = file
return w.initFd() return w.initFd()
} }
func (w *FileLogWriter) docheck(size int) { func (w *fileLogWriter) needRotate(size int, day int) bool {
w.startLock.Lock() return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
defer w.startLock.Unlock() (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
if w.Rotate && ((w.Maxlines > 0 && w.maxlines_curlines >= w.Maxlines) || (w.Daily && day != w.dailyOpenDate)
(w.Maxsize > 0 && w.maxsize_cursize >= w.Maxsize) ||
(w.Daily && time.Now().Day() != w.daily_opendate)) {
if err := w.DoRotate(); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
return
}
}
w.maxlines_curlines++
w.maxsize_cursize += size
} }
// write logger message into file. // WriteMsg write logger message into file.
func (w *FileLogWriter) WriteMsg(msg string, level int) error { func (w *fileLogWriter) WriteMsg(msg string, level int) error {
if level > w.Level { if level > w.Level {
return nil return nil
} }
n := 24 + len(msg) // 24 stand for the length "2013/06/23 21:00:22 [T] " //2016/01/12 21:34:33
w.docheck(n) now := time.Now()
w.Logger.Println(msg) y, mo, d := now.Date()
return nil h, mi, s := now.Clock()
//len(2006/01/02 15:03:04)==19
var buf [20]byte
t := 3
for y >= 10 {
p := y / 10
buf[t] = byte('0' + y - p*10)
y = p
t--
}
buf[0] = byte('0' + y)
buf[4] = '/'
if mo > 9 {
buf[5] = '1'
buf[6] = byte('0' + mo - 9)
} else {
buf[5] = '0'
buf[6] = byte('0' + mo)
}
buf[7] = '/'
t = d / 10
buf[8] = byte('0' + t)
buf[9] = byte('0' + d - t*10)
buf[10] = ' '
t = h / 10
buf[11] = byte('0' + t)
buf[12] = byte('0' + h - t*10)
buf[13] = ':'
t = mi / 10
buf[14] = byte('0' + t)
buf[15] = byte('0' + mi - t*10)
buf[16] = ':'
t = s / 10
buf[17] = byte('0' + t)
buf[18] = byte('0' + s - t*10)
buf[19] = ' '
msg = string(buf[0:]) + msg + "\n"
if w.Rotate {
if w.needRotate(len(msg), d) {
w.Lock()
if w.needRotate(len(msg), d) {
if err := w.doRotate(); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
}
}
w.Lock()
_, err := w.fileWriter.Write([]byte(msg))
if err == nil {
w.maxLinesCurLines++
w.maxSizeCurSize += len(msg)
}
w.Unlock()
return err
} }
func (w *FileLogWriter) createLogFile() (*os.File, error) { func (w *fileLogWriter) createLogFile() (*os.File, error) {
// Open the log file // Open the log file
fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660) fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, w.Perm)
return fd, err return fd, err
} }
func (w *FileLogWriter) initFd() error { func (w *fileLogWriter) initFd() error {
fd := w.mw.fd fd := w.fileWriter
finfo, err := fd.Stat() fInfo, err := fd.Stat()
if err != nil { if err != nil {
return fmt.Errorf("get stat err: %s\n", err) return fmt.Errorf("get stat err: %s\n", err)
} }
w.maxsize_cursize = int(finfo.Size()) w.maxSizeCurSize = int(fInfo.Size())
w.daily_opendate = time.Now().Day() w.dailyOpenDate = time.Now().Day()
w.maxlines_curlines = 0 w.maxLinesCurLines = 0
if finfo.Size() > 0 { if fInfo.Size() > 0 {
count, err := w.lines() count, err := w.lines()
if err != nil { if err != nil {
return err return err
} }
w.maxlines_curlines = count w.maxLinesCurLines = count
} }
return nil return nil
} }
func (w *FileLogWriter) lines() (int, error) { func (w *fileLogWriter) lines() (int, error) {
fd, err := os.Open(w.Filename) fd, err := os.Open(w.Filename)
if err != nil { if err != nil {
return 0, err return 0, err
@ -205,59 +235,60 @@ func (w *FileLogWriter) lines() (int, error) {
} }
// DoRotate means it need to write file in new file. // DoRotate means it need to write file in new file.
// new file name like xx.log.2013-01-01.2 // new file name like xx.2013-01-01.2.log
func (w *FileLogWriter) DoRotate() error { func (w *fileLogWriter) doRotate() error {
_, err := os.Lstat(w.Filename) _, err := os.Lstat(w.Filename)
if err == nil { // file exists if err != nil {
// Find the next available number return err
num := 1 }
fname := "" // file exists
for ; err == nil && num <= 999; num++ { // Find the next available number
fname = w.Filename + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), num) num := 1
_, err = os.Lstat(fname) fName := ""
} suffix := filepath.Ext(w.Filename)
// return error if the last file checked still existed filenameOnly := strings.TrimSuffix(w.Filename, suffix)
if err == nil { if suffix == "" {
return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename) suffix = ".log"
} }
for ; err == nil && num <= 999; num++ {
// block Logger's io.Writer fName = filenameOnly + fmt.Sprintf(".%s.%03d%s", time.Now().Format("2006-01-02"), num, suffix)
w.mw.Lock() _, err = os.Lstat(fName)
defer w.mw.Unlock() }
// return error if the last file checked still existed
fd := w.mw.fd if err == nil {
fd.Close() return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename)
// close fd before rename
// Rename the file to its newfound home
err = os.Rename(w.Filename, fname)
if err != nil {
return fmt.Errorf("Rotate: %s\n", err)
}
// re-start logger
err = w.startLogger()
if err != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", err)
}
go w.deleteOldLog()
} }
// close fileWriter before rename
w.fileWriter.Close()
// Rename the file to its new found name
// even if occurs error,we MUST guarantee to restart new logger
renameErr := os.Rename(w.Filename, fName)
// re-start logger
startLoggerErr := w.startLogger()
go w.deleteOldLog()
if startLoggerErr != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr)
}
if renameErr != nil {
return fmt.Errorf("Rotate: %s\n", renameErr)
}
return nil return nil
} }
func (w *FileLogWriter) deleteOldLog() { func (w *fileLogWriter) deleteOldLog() {
dir := filepath.Dir(w.Filename) dir := filepath.Dir(w.Filename)
filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) { filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
returnErr = fmt.Errorf("Unable to delete old log '%s', error: %+v", path, r) fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r)
fmt.Println(returnErr)
} }
}() }()
if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.Maxdays) { if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.MaxDays) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.Filename)) { if strings.HasPrefix(filepath.Base(path), filepath.Base(w.Filename)) {
os.Remove(path) os.Remove(path)
} }
@ -266,18 +297,18 @@ func (w *FileLogWriter) deleteOldLog() {
}) })
} }
// destroy file logger, close file writer. // Destroy close the file description, close file writer.
func (w *FileLogWriter) Destroy() { func (w *fileLogWriter) Destroy() {
w.mw.fd.Close() w.fileWriter.Close()
} }
// flush file logger. // Flush flush file logger.
// there are no buffering messages in file logger in memory. // there are no buffering messages in file logger in memory.
// flush file means sync file from disk. // flush file means sync file from disk.
func (w *FileLogWriter) Flush() { func (w *fileLogWriter) Flush() {
w.mw.fd.Sync() w.fileWriter.Sync()
} }
func init() { func init() {
Register("file", NewFileWriter) Register("file", newFileWriter)
} }

View File

@ -23,7 +23,7 @@ import (
"time" "time"
) )
func TestFile(t *testing.T) { func TestFile1(t *testing.T) {
log := NewLogger(10000) log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`) log.SetLogger("file", `{"filename":"test.log"}`)
log.Debug("debug") log.Debug("debug")
@ -34,25 +34,24 @@ func TestFile(t *testing.T) {
log.Alert("alert") log.Alert("alert")
log.Critical("critical") log.Critical("critical")
log.Emergency("emergency") log.Emergency("emergency")
time.Sleep(time.Second * 4)
f, err := os.Open("test.log") f, err := os.Open("test.log")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b := bufio.NewReader(f) b := bufio.NewReader(f)
linenum := 0 lineNum := 0
for { for {
line, _, err := b.ReadLine() line, _, err := b.ReadLine()
if err != nil { if err != nil {
break break
} }
if len(line) > 0 { if len(line) > 0 {
linenum++ lineNum++
} }
} }
var expected = LevelDebug + 1 var expected = LevelDebug + 1
if linenum != expected { if lineNum != expected {
t.Fatal(linenum, "not "+strconv.Itoa(expected)+" lines") t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
} }
os.Remove("test.log") os.Remove("test.log")
} }
@ -68,25 +67,24 @@ func TestFile2(t *testing.T) {
log.Alert("alert") log.Alert("alert")
log.Critical("critical") log.Critical("critical")
log.Emergency("emergency") log.Emergency("emergency")
time.Sleep(time.Second * 4)
f, err := os.Open("test2.log") f, err := os.Open("test2.log")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b := bufio.NewReader(f) b := bufio.NewReader(f)
linenum := 0 lineNum := 0
for { for {
line, _, err := b.ReadLine() line, _, err := b.ReadLine()
if err != nil { if err != nil {
break break
} }
if len(line) > 0 { if len(line) > 0 {
linenum++ lineNum++
} }
} }
var expected = LevelError + 1 var expected = LevelError + 1
if linenum != expected { if lineNum != expected {
t.Fatal(linenum, "not "+strconv.Itoa(expected)+" lines") t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
} }
os.Remove("test2.log") os.Remove("test2.log")
} }
@ -102,13 +100,13 @@ func TestFileRotate(t *testing.T) {
log.Alert("alert") log.Alert("alert")
log.Critical("critical") log.Critical("critical")
log.Emergency("emergency") log.Emergency("emergency")
time.Sleep(time.Second * 4) rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
rotatename := "test3.log" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) b, err := exists(rotateName)
b, err := exists(rotatename)
if !b || err != nil { if !b || err != nil {
os.Remove("test3.log")
t.Fatal("rotate not generated") t.Fatal("rotate not generated")
} }
os.Remove(rotatename) os.Remove(rotateName)
os.Remove("test3.log") os.Remove("test3.log")
} }
@ -131,3 +129,46 @@ func BenchmarkFile(b *testing.B) {
} }
os.Remove("test4.log") os.Remove("test4.log")
} }
func BenchmarkFileAsynchronous(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
log.Async()
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
os.Remove("test4.log")
}
func BenchmarkFileCallDepth(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
log.EnableFuncCallDepth(true)
log.SetLogFuncCallDepth(2)
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
os.Remove("test4.log")
}
func BenchmarkFileAsynchronousCallDepth(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
log.EnableFuncCallDepth(true)
log.SetLogFuncCallDepth(2)
log.Async()
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
os.Remove("test4.log")
}
func BenchmarkFileOnGoroutine(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
for i := 0; i < b.N; i++ {
go log.Debug("debug")
}
os.Remove("test4.log")
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package logs provide a general log interface
// Usage: // Usage:
// //
// import "github.com/astaxie/beego/logs" // import "github.com/astaxie/beego/logs"
@ -34,8 +35,10 @@ package logs
import ( import (
"fmt" "fmt"
"os"
"path" "path"
"runtime" "runtime"
"strconv"
"sync" "sync"
) )
@ -60,10 +63,10 @@ const (
LevelWarn = LevelWarning LevelWarn = LevelWarning
) )
type loggerType func() LoggerInterface type loggerType func() Logger
// LoggerInterface defines the behavior of a log provider. // Logger defines the behavior of a log provider.
type LoggerInterface interface { type Logger interface {
Init(config string) error Init(config string) error
WriteMsg(msg string, level int) error WriteMsg(msg string, level int) error
Destroy() Destroy()
@ -93,8 +96,13 @@ type BeeLogger struct {
enableFuncCallDepth bool enableFuncCallDepth bool
loggerFuncCallDepth int loggerFuncCallDepth int
asynchronous bool asynchronous bool
msg chan *logMsg msgChan chan *logMsg
outputs map[string]LoggerInterface outputs []*nameLogger
}
type nameLogger struct {
Logger
name string
} }
type logMsg struct { type logMsg struct {
@ -102,59 +110,79 @@ type logMsg struct {
msg string msg string
} }
var logMsgPool *sync.Pool
// NewLogger returns a new BeeLogger. // NewLogger returns a new BeeLogger.
// channellen means the number of messages in chan. // channelLen means the number of messages in chan(used where asynchronous is true).
// if the buffering chan is full, logger adapters write to file or other way. // if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channellen int64) *BeeLogger { func NewLogger(channelLen int64) *BeeLogger {
bl := new(BeeLogger) bl := new(BeeLogger)
bl.level = LevelDebug bl.level = LevelDebug
bl.loggerFuncCallDepth = 2 bl.loggerFuncCallDepth = 2
bl.msg = make(chan *logMsg, channellen) bl.msgChan = make(chan *logMsg, channelLen)
bl.outputs = make(map[string]LoggerInterface)
return bl return bl
} }
// Async set the log to asynchronous and start the goroutine
func (bl *BeeLogger) Async() *BeeLogger { func (bl *BeeLogger) Async() *BeeLogger {
bl.asynchronous = true bl.asynchronous = true
logMsgPool = &sync.Pool{
New: func() interface{} {
return &logMsg{}
},
}
go bl.startLogger() go bl.startLogger()
return bl return bl
} }
// SetLogger provides a given logger adapter into BeeLogger with config string. // SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}. // config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adaptername string, config string) error { func (bl *BeeLogger) SetLogger(adapterName string, config string) error {
bl.lock.Lock() bl.lock.Lock()
defer bl.lock.Unlock() defer bl.lock.Unlock()
if log, ok := adapters[adaptername]; ok { if log, ok := adapters[adapterName]; ok {
lg := log() lg := log()
err := lg.Init(config) err := lg.Init(config)
bl.outputs[adaptername] = lg
if err != nil { if err != nil {
fmt.Println("logs.BeeLogger.SetLogger: " + err.Error()) fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error())
return err return err
} }
bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg})
} else { } else {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername) return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
} }
return nil return nil
} }
// remove a logger adapter in BeeLogger. // DelLogger remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adaptername string) error { func (bl *BeeLogger) DelLogger(adapterName string) error {
bl.lock.Lock() bl.lock.Lock()
defer bl.lock.Unlock() defer bl.lock.Unlock()
if lg, ok := bl.outputs[adaptername]; ok { outputs := []*nameLogger{}
lg.Destroy() for _, lg := range bl.outputs {
delete(bl.outputs, adaptername) if lg.name == adapterName {
return nil lg.Destroy()
} else { } else {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername) outputs = append(outputs, lg)
}
}
if len(outputs) == len(bl.outputs) {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
}
bl.outputs = outputs
return nil
}
func (bl *BeeLogger) writeToLoggers(msg string, level int) {
for _, l := range bl.outputs {
err := l.WriteMsg(msg, level)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err)
}
} }
} }
func (bl *BeeLogger) writerMsg(loglevel int, msg string) error { func (bl *BeeLogger) writeMsg(logLevel int, msg string) error {
lm := new(logMsg)
lm.level = loglevel
if bl.enableFuncCallDepth { if bl.enableFuncCallDepth {
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
if !ok { if !ok {
@ -162,43 +190,37 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
line = 0 line = 0
} }
_, filename := path.Split(file) _, filename := path.Split(file)
lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg) msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "]" + msg
} else {
lm.msg = msg
} }
if bl.asynchronous { if bl.asynchronous {
bl.msg <- lm lm := logMsgPool.Get().(*logMsg)
lm.level = logLevel
lm.msg = msg
bl.msgChan <- lm
} else { } else {
for name, l := range bl.outputs { bl.writeToLoggers(msg, logLevel)
err := l.WriteMsg(lm.msg, lm.level)
if err != nil {
fmt.Println("unable to WriteMsg to adapter:", name, err)
return err
}
}
} }
return nil return nil
} }
// Set log message level. // SetLevel Set log message level.
//
// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning), // If message level (such as LevelDebug) is higher than logger level (such as LevelWarning),
// log providers will not even be sent the message. // log providers will not even be sent the message.
func (bl *BeeLogger) SetLevel(l int) { func (bl *BeeLogger) SetLevel(l int) {
bl.level = l bl.level = l
} }
// set log funcCallDepth // SetLogFuncCallDepth set log funcCallDepth
func (bl *BeeLogger) SetLogFuncCallDepth(d int) { func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
bl.loggerFuncCallDepth = d bl.loggerFuncCallDepth = d
} }
// get log funcCallDepth for wrapper // GetLogFuncCallDepth return log funcCallDepth for wrapper
func (bl *BeeLogger) GetLogFuncCallDepth() int { func (bl *BeeLogger) GetLogFuncCallDepth() int {
return bl.loggerFuncCallDepth return bl.loggerFuncCallDepth
} }
// enable log funcCallDepth // EnableFuncCallDepth enable log funcCallDepth
func (bl *BeeLogger) EnableFuncCallDepth(b bool) { func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
bl.enableFuncCallDepth = b bl.enableFuncCallDepth = b
} }
@ -208,137 +230,129 @@ func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
func (bl *BeeLogger) startLogger() { func (bl *BeeLogger) startLogger() {
for { for {
select { select {
case bm := <-bl.msg: case bm := <-bl.msgChan:
for _, l := range bl.outputs { bl.writeToLoggers(bm.msg, bm.level)
err := l.WriteMsg(bm.msg, bm.level) logMsgPool.Put(bm)
if err != nil {
fmt.Println("ERROR, unable to WriteMsg:", err)
}
}
} }
} }
} }
// Log EMERGENCY level message. // Emergency Log EMERGENCY level message.
func (bl *BeeLogger) Emergency(format string, v ...interface{}) { func (bl *BeeLogger) Emergency(format string, v ...interface{}) {
if LevelEmergency > bl.level { if LevelEmergency > bl.level {
return return
} }
msg := fmt.Sprintf("[M] "+format, v...) msg := fmt.Sprintf("[M] "+format, v...)
bl.writerMsg(LevelEmergency, msg) bl.writeMsg(LevelEmergency, msg)
} }
// Log ALERT level message. // Alert Log ALERT level message.
func (bl *BeeLogger) Alert(format string, v ...interface{}) { func (bl *BeeLogger) Alert(format string, v ...interface{}) {
if LevelAlert > bl.level { if LevelAlert > bl.level {
return return
} }
msg := fmt.Sprintf("[A] "+format, v...) msg := fmt.Sprintf("[A] "+format, v...)
bl.writerMsg(LevelAlert, msg) bl.writeMsg(LevelAlert, msg)
} }
// Log CRITICAL level message. // Critical Log CRITICAL level message.
func (bl *BeeLogger) Critical(format string, v ...interface{}) { func (bl *BeeLogger) Critical(format string, v ...interface{}) {
if LevelCritical > bl.level { if LevelCritical > bl.level {
return return
} }
msg := fmt.Sprintf("[C] "+format, v...) msg := fmt.Sprintf("[C] "+format, v...)
bl.writerMsg(LevelCritical, msg) bl.writeMsg(LevelCritical, msg)
} }
// Log ERROR level message. // Error Log ERROR level message.
func (bl *BeeLogger) Error(format string, v ...interface{}) { func (bl *BeeLogger) Error(format string, v ...interface{}) {
if LevelError > bl.level { if LevelError > bl.level {
return return
} }
msg := fmt.Sprintf("[E] "+format, v...) msg := fmt.Sprintf("[E] "+format, v...)
bl.writerMsg(LevelError, msg) bl.writeMsg(LevelError, msg)
} }
// Log WARNING level message. // Warning Log WARNING level message.
func (bl *BeeLogger) Warning(format string, v ...interface{}) { func (bl *BeeLogger) Warning(format string, v ...interface{}) {
if LevelWarning > bl.level { if LevelWarning > bl.level {
return return
} }
msg := fmt.Sprintf("[W] "+format, v...) msg := fmt.Sprintf("[W] "+format, v...)
bl.writerMsg(LevelWarning, msg) bl.writeMsg(LevelWarning, msg)
} }
// Log NOTICE level message. // Notice Log NOTICE level message.
func (bl *BeeLogger) Notice(format string, v ...interface{}) { func (bl *BeeLogger) Notice(format string, v ...interface{}) {
if LevelNotice > bl.level { if LevelNotice > bl.level {
return return
} }
msg := fmt.Sprintf("[N] "+format, v...) msg := fmt.Sprintf("[N] "+format, v...)
bl.writerMsg(LevelNotice, msg) bl.writeMsg(LevelNotice, msg)
} }
// Log INFORMATIONAL level message. // Informational Log INFORMATIONAL level message.
func (bl *BeeLogger) Informational(format string, v ...interface{}) { func (bl *BeeLogger) Informational(format string, v ...interface{}) {
if LevelInformational > bl.level { if LevelInformational > bl.level {
return return
} }
msg := fmt.Sprintf("[I] "+format, v...) msg := fmt.Sprintf("[I] "+format, v...)
bl.writerMsg(LevelInformational, msg) bl.writeMsg(LevelInformational, msg)
} }
// Log DEBUG level message. // Debug Log DEBUG level message.
func (bl *BeeLogger) Debug(format string, v ...interface{}) { func (bl *BeeLogger) Debug(format string, v ...interface{}) {
if LevelDebug > bl.level { if LevelDebug > bl.level {
return return
} }
msg := fmt.Sprintf("[D] "+format, v...) msg := fmt.Sprintf("[D] "+format, v...)
bl.writerMsg(LevelDebug, msg) bl.writeMsg(LevelDebug, msg)
} }
// Log WARN level message. // Warn Log WARN level message.
// compatibility alias for Warning() // compatibility alias for Warning()
func (bl *BeeLogger) Warn(format string, v ...interface{}) { func (bl *BeeLogger) Warn(format string, v ...interface{}) {
if LevelWarning > bl.level { if LevelWarning > bl.level {
return return
} }
msg := fmt.Sprintf("[W] "+format, v...) msg := fmt.Sprintf("[W] "+format, v...)
bl.writerMsg(LevelWarning, msg) bl.writeMsg(LevelWarning, msg)
} }
// Log INFO level message. // Info Log INFO level message.
// compatibility alias for Informational() // compatibility alias for Informational()
func (bl *BeeLogger) Info(format string, v ...interface{}) { func (bl *BeeLogger) Info(format string, v ...interface{}) {
if LevelInformational > bl.level { if LevelInformational > bl.level {
return return
} }
msg := fmt.Sprintf("[I] "+format, v...) msg := fmt.Sprintf("[I] "+format, v...)
bl.writerMsg(LevelInformational, msg) bl.writeMsg(LevelInformational, msg)
} }
// Log TRACE level message. // Trace Log TRACE level message.
// compatibility alias for Debug() // compatibility alias for Debug()
func (bl *BeeLogger) Trace(format string, v ...interface{}) { func (bl *BeeLogger) Trace(format string, v ...interface{}) {
if LevelDebug > bl.level { if LevelDebug > bl.level {
return return
} }
msg := fmt.Sprintf("[D] "+format, v...) msg := fmt.Sprintf("[D] "+format, v...)
bl.writerMsg(LevelDebug, msg) bl.writeMsg(LevelDebug, msg)
} }
// flush all chan data. // Flush flush all chan data.
func (bl *BeeLogger) Flush() { func (bl *BeeLogger) Flush() {
for _, l := range bl.outputs { for _, l := range bl.outputs {
l.Flush() l.Flush()
} }
} }
// close logger, flush all chan data and destroy all adapters in BeeLogger. // Close close logger, flush all chan data and destroy all adapters in BeeLogger.
func (bl *BeeLogger) Close() { func (bl *BeeLogger) Close() {
for { for {
if len(bl.msg) > 0 { if len(bl.msgChan) > 0 {
bm := <-bl.msg bm := <-bl.msgChan
for _, l := range bl.outputs { bl.writeToLoggers(bm.msg, bm.level)
err := l.WriteMsg(bm.msg, bm.level) logMsgPool.Put(bm)
if err != nil {
fmt.Println("ERROR, unable to WriteMsg (while closing logger):", err)
}
}
continue continue
} }
break break

View File

@ -24,31 +24,26 @@ import (
"time" "time"
) )
const ( // SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server.
// no usage type SMTPWriter struct {
// subjectPhrase = "Diagnostic message from server" Username string `json:"username"`
)
// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
type SmtpWriter struct {
Username string `json:"Username"`
Password string `json:"password"` Password string `json:"password"`
Host string `json:"Host"` Host string `json:"host"`
Subject string `json:"subject"` Subject string `json:"subject"`
FromAddress string `json:"fromAddress"` FromAddress string `json:"fromAddress"`
RecipientAddresses []string `json:"sendTos"` RecipientAddresses []string `json:"sendTos"`
Level int `json:"level"` Level int `json:"level"`
} }
// create smtp writer. // NewSMTPWriter create smtp writer.
func NewSmtpWriter() LoggerInterface { func newSMTPWriter() Logger {
return &SmtpWriter{Level: LevelTrace} return &SMTPWriter{Level: LevelTrace}
} }
// init smtp writer with json config. // Init smtp writer with json config.
// config like: // config like:
// { // {
// "Username":"example@gmail.com", // "username":"example@gmail.com",
// "password:"password", // "password:"password",
// "host":"smtp.gmail.com:465", // "host":"smtp.gmail.com:465",
// "subject":"email title", // "subject":"email title",
@ -56,7 +51,7 @@ func NewSmtpWriter() LoggerInterface {
// "sendTos":["email1","email2"], // "sendTos":["email1","email2"],
// "level":LevelError // "level":LevelError
// } // }
func (s *SmtpWriter) Init(jsonconfig string) error { func (s *SMTPWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s) err := json.Unmarshal([]byte(jsonconfig), s)
if err != nil { if err != nil {
return err return err
@ -64,7 +59,7 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
return nil return nil
} }
func (s *SmtpWriter) GetSmtpAuth(host string) smtp.Auth { func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth {
if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 { if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 {
return nil return nil
} }
@ -76,7 +71,7 @@ func (s *SmtpWriter) GetSmtpAuth(host string) smtp.Auth {
) )
} }
func (s *SmtpWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error {
client, err := smtp.Dial(hostAddressWithPort) client, err := smtp.Dial(hostAddressWithPort)
if err != nil { if err != nil {
return err return err
@ -129,9 +124,9 @@ func (s *SmtpWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd
return nil return nil
} }
// write message in smtp writer. // WriteMsg write message in smtp writer.
// it will send an email with subject and only this message. // it will send an email with subject and only this message.
func (s *SmtpWriter) WriteMsg(msg string, level int) error { func (s *SMTPWriter) WriteMsg(msg string, level int) error {
if level > s.Level { if level > s.Level {
return nil return nil
} }
@ -139,27 +134,27 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
hp := strings.Split(s.Host, ":") hp := strings.Split(s.Host, ":")
// Set up authentication information. // Set up authentication information.
auth := s.GetSmtpAuth(hp[0]) auth := s.getSMTPAuth(hp[0])
// Connect to the server, authenticate, set the sender and recipient, // Connect to the server, authenticate, set the sender and recipient,
// and send the email all in one step. // and send the email all in one step.
content_type := "Content-Type: text/plain" + "; charset=UTF-8" contentType := "Content-Type: text/plain" + "; charset=UTF-8"
mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress +
">\r\nSubject: " + s.Subject + "\r\n" + content_type + "\r\n\r\n" + fmt.Sprintf(".%s", time.Now().Format("2006-01-02 15:04:05")) + msg) ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", time.Now().Format("2006-01-02 15:04:05")) + msg)
return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg)
} }
// implementing method. empty. // Flush implementing method. empty.
func (s *SmtpWriter) Flush() { func (s *SMTPWriter) Flush() {
return return
} }
// implementing method. empty. // Destroy implementing method. empty.
func (s *SmtpWriter) Destroy() { func (s *SMTPWriter) Destroy() {
return return
} }
func init() { func init() {
Register("smtp", NewSmtpWriter) Register("smtp", newSMTPWriter)
} }

View File

@ -1,212 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package beego
import (
"bytes"
"compress/flate"
"compress/gzip"
"errors"
"io"
"io/ioutil"
"net/http"
"os"
"strings"
"sync"
"time"
)
var gmfim map[string]*memFileInfo = make(map[string]*memFileInfo)
var lock sync.RWMutex
// OpenMemZipFile returns MemFile object with a compressed static file.
// it's used for serve static file if gzip enable.
func openMemZipFile(path string, zip string) (*memFile, error) {
osfile, e := os.Open(path)
if e != nil {
return nil, e
}
defer osfile.Close()
osfileinfo, e := osfile.Stat()
if e != nil {
return nil, e
}
modtime := osfileinfo.ModTime()
fileSize := osfileinfo.Size()
lock.RLock()
cfi, ok := gmfim[zip+":"+path]
lock.RUnlock()
if !(ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize) {
var content []byte
if zip == "gzip" {
var zipbuf bytes.Buffer
gzipwriter, e := gzip.NewWriterLevel(&zipbuf, gzip.BestCompression)
if e != nil {
return nil, e
}
_, e = io.Copy(gzipwriter, osfile)
gzipwriter.Close()
if e != nil {
return nil, e
}
content, e = ioutil.ReadAll(&zipbuf)
if e != nil {
return nil, e
}
} else if zip == "deflate" {
var zipbuf bytes.Buffer
deflatewriter, e := flate.NewWriter(&zipbuf, flate.BestCompression)
if e != nil {
return nil, e
}
_, e = io.Copy(deflatewriter, osfile)
deflatewriter.Close()
if e != nil {
return nil, e
}
content, e = ioutil.ReadAll(&zipbuf)
if e != nil {
return nil, e
}
} else {
content, e = ioutil.ReadAll(osfile)
if e != nil {
return nil, e
}
}
cfi = &memFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
lock.Lock()
defer lock.Unlock()
gmfim[zip+":"+path] = cfi
}
return &memFile{fi: cfi, offset: 0}, nil
}
// MemFileInfo contains a compressed file bytes and file information.
// it implements os.FileInfo interface.
type memFileInfo struct {
os.FileInfo
modTime time.Time
content []byte
contentSize int64
fileSize int64
}
// Name returns the compressed filename.
func (fi *memFileInfo) Name() string {
return fi.Name()
}
// Size returns the raw file content size, not compressed size.
func (fi *memFileInfo) Size() int64 {
return fi.contentSize
}
// Mode returns file mode.
func (fi *memFileInfo) Mode() os.FileMode {
return fi.Mode()
}
// ModTime returns the last modified time of raw file.
func (fi *memFileInfo) ModTime() time.Time {
return fi.modTime
}
// IsDir returns the compressing file is a directory or not.
func (fi *memFileInfo) IsDir() bool {
return fi.IsDir()
}
// return nil. implement the os.FileInfo interface method.
func (fi *memFileInfo) Sys() interface{} {
return nil
}
// MemFile contains MemFileInfo and bytes offset when reading.
// it implements io.Reader,io.ReadCloser and io.Seeker.
type memFile struct {
fi *memFileInfo
offset int64
}
// Close memfile.
func (f *memFile) Close() error {
return nil
}
// Get os.FileInfo of memfile.
func (f *memFile) Stat() (os.FileInfo, error) {
return f.fi, nil
}
// read os.FileInfo of files in directory of memfile.
// it returns empty slice.
func (f *memFile) Readdir(count int) ([]os.FileInfo, error) {
infos := []os.FileInfo{}
return infos, nil
}
// Read bytes from the compressed file bytes.
func (f *memFile) Read(p []byte) (n int, err error) {
if len(f.fi.content)-int(f.offset) >= len(p) {
n = len(p)
} else {
n = len(f.fi.content) - int(f.offset)
err = io.EOF
}
copy(p, f.fi.content[f.offset:f.offset+int64(n)])
f.offset += int64(n)
return
}
var errWhence = errors.New("Seek: invalid whence")
var errOffset = errors.New("Seek: invalid offset")
// Read bytes from the compressed file bytes by seeker.
func (f *memFile) Seek(offset int64, whence int) (ret int64, err error) {
switch whence {
default:
return 0, errWhence
case os.SEEK_SET:
case os.SEEK_CUR:
offset += f.offset
case os.SEEK_END:
offset += int64(len(f.fi.content))
}
if offset < 0 || int(offset) > len(f.fi.content) {
return 0, errOffset
}
f.offset = offset
return f.offset, nil
}
// GetAcceptEncodingZip returns accept encoding format in http header.
// zip is first, then deflate if both accepted.
// If no accepted, return empty string.
func getAcceptEncodingZip(r *http.Request) string {
ss := r.Header.Get("Accept-Encoding")
ss = strings.ToLower(ss)
if strings.Contains(ss, "gzip") {
return "gzip"
} else if strings.Contains(ss, "deflate") {
return "deflate"
} else {
return ""
}
}

View File

@ -1,71 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Usage:
//
// import "github.com/astaxie/beego/middleware"
//
// I18N = middleware.NewLocale("conf/i18n.conf", beego.AppConfig.String("language"))
//
// more docs: http://beego.me/docs/module/i18n.md
package middleware
import (
"encoding/json"
"io/ioutil"
"os"
)
type Translation struct {
filepath string
CurrentLocal string
Locales map[string]map[string]string
}
func NewLocale(filepath string, defaultlocal string) *Translation {
file, err := os.Open(filepath)
if err != nil {
panic("open " + filepath + " err :" + err.Error())
}
data, err := ioutil.ReadAll(file)
if err != nil {
panic("read " + filepath + " err :" + err.Error())
}
i18n := make(map[string]map[string]string)
if err = json.Unmarshal(data, &i18n); err != nil {
panic("json.Unmarshal " + filepath + " err :" + err.Error())
}
return &Translation{
filepath: filepath,
CurrentLocal: defaultlocal,
Locales: i18n,
}
}
func (t *Translation) SetLocale(local string) {
t.CurrentLocal = local
}
func (t *Translation) Translate(key string, local string) string {
if local == "" {
local = t.CurrentLocal
}
if ct, ok := t.Locales[key]; ok {
if v, o := ct[local]; o {
return v
}
}
return key
}

View File

@ -14,33 +14,40 @@
package migration package migration
// Table store the tablename and Column
type Table struct { type Table struct {
TableName string TableName string
Columns []*Column Columns []*Column
} }
// Create return the create sql
func (t *Table) Create() string { func (t *Table) Create() string {
return "" return ""
} }
// Drop return the drop sql
func (t *Table) Drop() string { func (t *Table) Drop() string {
return "" return ""
} }
// Column define the columns name type and Default
type Column struct { type Column struct {
Name string Name string
Type string Type string
Default interface{} Default interface{}
} }
// Create return create sql with the provided tbname and columns
func Create(tbname string, columns ...Column) string { func Create(tbname string, columns ...Column) string {
return "" return ""
} }
// Drop return the drop sql with the provided tbname and columns
func Drop(tbname string, columns ...Column) string { func Drop(tbname string, columns ...Column) string {
return "" return ""
} }
// TableDDL is still in think
func TableDDL(tbname string, columns ...Column) string { func TableDDL(tbname string, columns ...Column) string {
return "" return ""
} }

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// migration package for migration // Package migration is used for migration
// //
// The table structure is as follow: // The table structure is as follow:
// //
@ -39,8 +39,8 @@ import (
// const the data format for the bee generate migration datatype // const the data format for the bee generate migration datatype
const ( const (
M_DATE_FORMAT = "20060102_150405" DateFormat = "20060102_150405"
M_DB_DATE_FORMAT = "2006-01-02 15:04:05" DBDateFormat = "2006-01-02 15:04:05"
) )
// Migrationer is an interface for all Migration struct // Migrationer is an interface for all Migration struct
@ -60,24 +60,24 @@ func init() {
migrationMap = make(map[string]Migrationer) migrationMap = make(map[string]Migrationer)
} }
// the basic type which will implement the basic type // Migration the basic type which will implement the basic type
type Migration struct { type Migration struct {
sqls []string sqls []string
Created string Created string
} }
// implement in the Inheritance struct for upgrade // Up implement in the Inheritance struct for upgrade
func (m *Migration) Up() { func (m *Migration) Up() {
} }
// implement in the Inheritance struct for down // Down implement in the Inheritance struct for down
func (m *Migration) Down() { func (m *Migration) Down() {
} }
// add sql want to execute // SQL add sql want to execute
func (m *Migration) Sql(sql string) { func (m *Migration) SQL(sql string) {
m.sqls = append(m.sqls, sql) m.sqls = append(m.sqls, sql)
} }
@ -86,7 +86,7 @@ func (m *Migration) Reset() {
m.sqls = make([]string, 0) m.sqls = make([]string, 0)
} }
// execute the sql already add in the sql // Exec execute the sql already add in the sql
func (m *Migration) Exec(name, status string) error { func (m *Migration) Exec(name, status string) error {
o := orm.NewOrm() o := orm.NewOrm()
for _, s := range m.sqls { for _, s := range m.sqls {
@ -104,33 +104,32 @@ func (m *Migration) addOrUpdateRecord(name, status string) error {
o := orm.NewOrm() o := orm.NewOrm()
if status == "down" { if status == "down" {
status = "rollback" status = "rollback"
p, err := o.Raw("update migrations set `status` = ?, `rollback_statements` = ?, `created_at` = ? where name = ?").Prepare() p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare()
if err != nil { if err != nil {
return nil return nil
} }
_, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(M_DB_DATE_FORMAT), name) _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name)
return err
} else {
status = "update"
p, err := o.Raw("insert into migrations(`name`, `created_at`, `statements`, `status`) values(?,?,?,?)").Prepare()
if err != nil {
return err
}
_, err = p.Exec(name, time.Now().Format(M_DB_DATE_FORMAT), strings.Join(m.sqls, "; "), status)
return err return err
} }
status = "update"
p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare()
if err != nil {
return err
}
_, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status)
return err
} }
// get the unixtime from the Created // GetCreated get the unixtime from the Created
func (m *Migration) GetCreated() int64 { func (m *Migration) GetCreated() int64 {
t, err := time.Parse(M_DATE_FORMAT, m.Created) t, err := time.Parse(DateFormat, m.Created)
if err != nil { if err != nil {
return 0 return 0
} }
return t.Unix() return t.Unix()
} }
// register the Migration in the map // Register register the Migration in the map
func Register(name string, m Migrationer) error { func Register(name string, m Migrationer) error {
if _, ok := migrationMap[name]; ok { if _, ok := migrationMap[name]; ok {
return errors.New("already exist name:" + name) return errors.New("already exist name:" + name)
@ -139,7 +138,7 @@ func Register(name string, m Migrationer) error {
return nil return nil
} }
// upgrate the migration from lasttime // Upgrade upgrate the migration from lasttime
func Upgrade(lasttime int64) error { func Upgrade(lasttime int64) error {
sm := sortMap(migrationMap) sm := sortMap(migrationMap)
i := 0 i := 0
@ -163,7 +162,7 @@ func Upgrade(lasttime int64) error {
return nil return nil
} }
//rollback the migration by the name // Rollback rollback the migration by the name
func Rollback(name string) error { func Rollback(name string) error {
if v, ok := migrationMap[name]; ok { if v, ok := migrationMap[name]; ok {
beego.Info("start rollback") beego.Info("start rollback")
@ -178,14 +177,13 @@ func Rollback(name string) error {
beego.Info("end rollback") beego.Info("end rollback")
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
return nil return nil
} else {
beego.Error("not exist the migrationMap name:" + name)
time.Sleep(2 * time.Second)
return errors.New("not exist the migrationMap name:" + name)
} }
beego.Error("not exist the migrationMap name:" + name)
time.Sleep(2 * time.Second)
return errors.New("not exist the migrationMap name:" + name)
} }
// reset all migration // Reset reset all migration
// run all migration's down function // run all migration's down function
func Reset() error { func Reset() error {
sm := sortMap(migrationMap) sm := sortMap(migrationMap)
@ -214,7 +212,7 @@ func Reset() error {
return nil return nil
} }
// first Reset, then Upgrade // Refresh first Reset, then Upgrade
func Refresh() error { func Refresh() error {
err := Reset() err := Reset()
if err != nil { if err != nil {

13
mime.go
View File

@ -14,11 +14,7 @@
package beego package beego
import ( var mimemaps = map[string]string{
"mime"
)
var mimemaps map[string]string = map[string]string{
".3dm": "x-world/x-3dmf", ".3dm": "x-world/x-3dmf",
".3dmf": "x-world/x-3dmf", ".3dmf": "x-world/x-3dmf",
".7z": "application/x-7z-compressed", ".7z": "application/x-7z-compressed",
@ -558,10 +554,3 @@ var mimemaps map[string]string = map[string]string{
".oex": "application/x-opera-extension", ".oex": "application/x-opera-extension",
".mustache": "text/html", ".mustache": "text/html",
} }
func initMime() error {
for k, v := range mimemaps {
mime.AddExtensionType(k, v)
}
return nil
}

View File

@ -23,16 +23,17 @@ import (
type namespaceCond func(*beecontext.Context) bool type namespaceCond func(*beecontext.Context) bool
type innnerNamespace func(*Namespace) // LinkNamespace used as link action
type LinkNamespace func(*Namespace)
// Namespace is store all the info // Namespace is store all the info
type Namespace struct { type Namespace struct {
prefix string prefix string
handlers *ControllerRegistor handlers *ControllerRegister
} }
// get new Namespace // NewNamespace get new Namespace
func NewNamespace(prefix string, params ...innnerNamespace) *Namespace { func NewNamespace(prefix string, params ...LinkNamespace) *Namespace {
ns := &Namespace{ ns := &Namespace{
prefix: prefix, prefix: prefix,
handlers: NewControllerRegister(), handlers: NewControllerRegister(),
@ -43,7 +44,7 @@ func NewNamespace(prefix string, params ...innnerNamespace) *Namespace {
return ns return ns
} }
// set condtion function // Cond set condtion function
// if cond return true can run this namespace, else can't // if cond return true can run this namespace, else can't
// usage: // usage:
// ns.Cond(func (ctx *context.Context) bool{ // ns.Cond(func (ctx *context.Context) bool{
@ -72,7 +73,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace {
return n return n
} }
// add filter in the Namespace // Filter add filter in the Namespace
// action has before & after // action has before & after
// FilterFunc // FilterFunc
// usage: // usage:
@ -95,98 +96,98 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
return n return n
} }
// same as beego.Rourer // Router same as beego.Rourer
// refer: https://godoc.org/github.com/astaxie/beego#Router // refer: https://godoc.org/github.com/astaxie/beego#Router
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
n.handlers.Add(rootpath, c, mappingMethods...) n.handlers.Add(rootpath, c, mappingMethods...)
return n return n
} }
// same as beego.AutoRouter // AutoRouter same as beego.AutoRouter
// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter // refer: https://godoc.org/github.com/astaxie/beego#AutoRouter
func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace {
n.handlers.AddAuto(c) n.handlers.AddAuto(c)
return n return n
} }
// same as beego.AutoPrefix // AutoPrefix same as beego.AutoPrefix
// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix // refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix
func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace {
n.handlers.AddAutoPrefix(prefix, c) n.handlers.AddAutoPrefix(prefix, c)
return n return n
} }
// same as beego.Get // Get same as beego.Get
// refer: https://godoc.org/github.com/astaxie/beego#Get // refer: https://godoc.org/github.com/astaxie/beego#Get
func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace {
n.handlers.Get(rootpath, f) n.handlers.Get(rootpath, f)
return n return n
} }
// same as beego.Post // Post same as beego.Post
// refer: https://godoc.org/github.com/astaxie/beego#Post // refer: https://godoc.org/github.com/astaxie/beego#Post
func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace {
n.handlers.Post(rootpath, f) n.handlers.Post(rootpath, f)
return n return n
} }
// same as beego.Delete // Delete same as beego.Delete
// refer: https://godoc.org/github.com/astaxie/beego#Delete // refer: https://godoc.org/github.com/astaxie/beego#Delete
func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace {
n.handlers.Delete(rootpath, f) n.handlers.Delete(rootpath, f)
return n return n
} }
// same as beego.Put // Put same as beego.Put
// refer: https://godoc.org/github.com/astaxie/beego#Put // refer: https://godoc.org/github.com/astaxie/beego#Put
func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace {
n.handlers.Put(rootpath, f) n.handlers.Put(rootpath, f)
return n return n
} }
// same as beego.Head // Head same as beego.Head
// refer: https://godoc.org/github.com/astaxie/beego#Head // refer: https://godoc.org/github.com/astaxie/beego#Head
func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace {
n.handlers.Head(rootpath, f) n.handlers.Head(rootpath, f)
return n return n
} }
// same as beego.Options // Options same as beego.Options
// refer: https://godoc.org/github.com/astaxie/beego#Options // refer: https://godoc.org/github.com/astaxie/beego#Options
func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace {
n.handlers.Options(rootpath, f) n.handlers.Options(rootpath, f)
return n return n
} }
// same as beego.Patch // Patch same as beego.Patch
// refer: https://godoc.org/github.com/astaxie/beego#Patch // refer: https://godoc.org/github.com/astaxie/beego#Patch
func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace {
n.handlers.Patch(rootpath, f) n.handlers.Patch(rootpath, f)
return n return n
} }
// same as beego.Any // Any same as beego.Any
// refer: https://godoc.org/github.com/astaxie/beego#Any // refer: https://godoc.org/github.com/astaxie/beego#Any
func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace {
n.handlers.Any(rootpath, f) n.handlers.Any(rootpath, f)
return n return n
} }
// same as beego.Handler // Handler same as beego.Handler
// refer: https://godoc.org/github.com/astaxie/beego#Handler // refer: https://godoc.org/github.com/astaxie/beego#Handler
func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace {
n.handlers.Handler(rootpath, h) n.handlers.Handler(rootpath, h)
return n return n
} }
// add include class // Include add include class
// refer: https://godoc.org/github.com/astaxie/beego#Include // refer: https://godoc.org/github.com/astaxie/beego#Include
func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
n.handlers.Include(cList...) n.handlers.Include(cList...)
return n return n
} }
// nest Namespace // Namespace add nest Namespace
// usage: // usage:
//ns := beego.NewNamespace(“/v1”). //ns := beego.NewNamespace(“/v1”).
//Namespace( //Namespace(
@ -230,7 +231,7 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
return n return n
} }
// register Namespace into beego.Handler // AddNamespace register Namespace into beego.Handler
// support multi Namespace // support multi Namespace
func AddNamespace(nl ...*Namespace) { func AddNamespace(nl ...*Namespace) {
for _, n := range nl { for _, n := range nl {
@ -275,113 +276,113 @@ func addPrefix(t *Tree, prefix string) {
} }
// Namespace Condition // NSCond is Namespace Condition
func NSCond(cond namespaceCond) innnerNamespace { func NSCond(cond namespaceCond) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Cond(cond) ns.Cond(cond)
} }
} }
// Namespace BeforeRouter filter // NSBefore Namespace BeforeRouter filter
func NSBefore(filiterList ...FilterFunc) innnerNamespace { func NSBefore(filiterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Filter("before", filiterList...) ns.Filter("before", filiterList...)
} }
} }
// Namespace FinishRouter filter // NSAfter add Namespace FinishRouter filter
func NSAfter(filiterList ...FilterFunc) innnerNamespace { func NSAfter(filiterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Filter("after", filiterList...) ns.Filter("after", filiterList...)
} }
} }
// Namespace Include ControllerInterface // NSInclude Namespace Include ControllerInterface
func NSInclude(cList ...ControllerInterface) innnerNamespace { func NSInclude(cList ...ControllerInterface) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Include(cList...) ns.Include(cList...)
} }
} }
// Namespace Router // NSRouter call Namespace Router
func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) innnerNamespace { func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Router(rootpath, c, mappingMethods...) ns.Router(rootpath, c, mappingMethods...)
} }
} }
// Namespace Get // NSGet call Namespace Get
func NSGet(rootpath string, f FilterFunc) innnerNamespace { func NSGet(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Get(rootpath, f) ns.Get(rootpath, f)
} }
} }
// Namespace Post // NSPost call Namespace Post
func NSPost(rootpath string, f FilterFunc) innnerNamespace { func NSPost(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Post(rootpath, f) ns.Post(rootpath, f)
} }
} }
// Namespace Head // NSHead call Namespace Head
func NSHead(rootpath string, f FilterFunc) innnerNamespace { func NSHead(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Head(rootpath, f) ns.Head(rootpath, f)
} }
} }
// Namespace Put // NSPut call Namespace Put
func NSPut(rootpath string, f FilterFunc) innnerNamespace { func NSPut(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Put(rootpath, f) ns.Put(rootpath, f)
} }
} }
// Namespace Delete // NSDelete call Namespace Delete
func NSDelete(rootpath string, f FilterFunc) innnerNamespace { func NSDelete(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Delete(rootpath, f) ns.Delete(rootpath, f)
} }
} }
// Namespace Any // NSAny call Namespace Any
func NSAny(rootpath string, f FilterFunc) innnerNamespace { func NSAny(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Any(rootpath, f) ns.Any(rootpath, f)
} }
} }
// Namespace Options // NSOptions call Namespace Options
func NSOptions(rootpath string, f FilterFunc) innnerNamespace { func NSOptions(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Options(rootpath, f) ns.Options(rootpath, f)
} }
} }
// Namespace Patch // NSPatch call Namespace Patch
func NSPatch(rootpath string, f FilterFunc) innnerNamespace { func NSPatch(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Patch(rootpath, f) ns.Patch(rootpath, f)
} }
} }
//Namespace AutoRouter // NSAutoRouter call Namespace AutoRouter
func NSAutoRouter(c ControllerInterface) innnerNamespace { func NSAutoRouter(c ControllerInterface) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.AutoRouter(c) ns.AutoRouter(c)
} }
} }
// Namespace AutoPrefix // NSAutoPrefix call Namespace AutoPrefix
func NSAutoPrefix(prefix string, c ControllerInterface) innnerNamespace { func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.AutoPrefix(prefix, c) ns.AutoPrefix(prefix, c)
} }
} }
// Namespace add sub Namespace // NSNamespace add sub Namespace
func NSNamespace(prefix string, params ...innnerNamespace) innnerNamespace { func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
n := NewNamespace(prefix, params...) n := NewNamespace(prefix, params...)
ns.Namespace(n) ns.Namespace(n)

View File

@ -46,7 +46,7 @@ func printHelp(errs ...string) {
os.Exit(2) os.Exit(2)
} }
// listen for orm command and then run it if command arguments passed. // RunCommand listen for orm command and then run it if command arguments passed.
func RunCommand() { func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" { if len(os.Args) < 2 || os.Args[1] != "orm" {
return return
@ -100,7 +100,7 @@ func (d *commandSyncDb) Parse(args []string) {
func (d *commandSyncDb) Run() error { func (d *commandSyncDb) Run() error {
var drops []string var drops []string
if d.force { if d.force {
drops = getDbDropSql(d.al) drops = getDbDropSQL(d.al)
} }
db := d.al.DB db := d.al.DB
@ -124,7 +124,7 @@ func (d *commandSyncDb) Run() error {
} }
} }
sqls, indexes := getDbCreateSql(d.al) sqls, indexes := getDbCreateSQL(d.al)
tables, err := d.al.DbBaser.GetTables(db) tables, err := d.al.DbBaser.GetTables(db)
if err != nil { if err != nil {
@ -180,7 +180,7 @@ func (d *commandSyncDb) Run() error {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
} }
query := idx.Sql query := idx.SQL
_, err := db.Exec(query) _, err := db.Exec(query)
if d.verbose { if d.verbose {
fmt.Printf(" %s\n", query) fmt.Printf(" %s\n", query)
@ -203,7 +203,7 @@ func (d *commandSyncDb) Run() error {
queries := []string{sqls[i]} queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] { for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql) queries = append(queries, idx.SQL)
} }
for _, query := range queries { for _, query := range queries {
@ -228,12 +228,12 @@ func (d *commandSyncDb) Run() error {
} }
// database creation commander interface implement. // database creation commander interface implement.
type commandSqlAll struct { type commandSQLAll struct {
al *alias al *alias
} }
// parse orm command line arguments. // parse orm command line arguments.
func (d *commandSqlAll) Parse(args []string) { func (d *commandSQLAll) Parse(args []string) {
var name string var name string
flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError) flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError)
@ -244,13 +244,13 @@ func (d *commandSqlAll) Parse(args []string) {
} }
// run orm line command. // run orm line command.
func (d *commandSqlAll) Run() error { func (d *commandSQLAll) Run() error {
sqls, indexes := getDbCreateSql(d.al) sqls, indexes := getDbCreateSQL(d.al)
var all []string var all []string
for i, mi := range modelCache.allOrdered() { for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]} queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] { for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql) queries = append(queries, idx.SQL)
} }
sql := strings.Join(queries, "\n") sql := strings.Join(queries, "\n")
all = append(all, sql) all = append(all, sql)
@ -262,10 +262,10 @@ func (d *commandSqlAll) Run() error {
func init() { func init() {
commands["syncdb"] = new(commandSyncDb) commands["syncdb"] = new(commandSyncDb)
commands["sqlall"] = new(commandSqlAll) commands["sqlall"] = new(commandSQLAll)
} }
// run syncdb command line. // RunSyncdb run syncdb command line.
// name means table's alias name. default is "default". // name means table's alias name. default is "default".
// force means run next sql if the current is error. // force means run next sql if the current is error.
// verbose means show all info when running command or not. // verbose means show all info when running command or not.

View File

@ -23,11 +23,11 @@ import (
type dbIndex struct { type dbIndex struct {
Table string Table string
Name string Name string
Sql string SQL string
} }
// create database drop sql. // create database drop sql.
func getDbDropSql(al *alias) (sqls []string) { func getDbDropSQL(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")
os.Exit(2) os.Exit(2)
@ -45,13 +45,14 @@ func getDbDropSql(al *alias) (sqls []string) {
func getColumnTyp(al *alias, fi *fieldInfo) (col string) { func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes() T := al.DbBaser.DbTypes()
fieldType := fi.fieldType fieldType := fi.fieldType
fieldSize := fi.size
checkColumn: checkColumn:
switch fieldType { switch fieldType {
case TypeBooleanField: case TypeBooleanField:
col = T["bool"] col = T["bool"]
case TypeCharField: case TypeCharField:
col = fmt.Sprintf(T["string"], fi.size) col = fmt.Sprintf(T["string"], fieldSize)
case TypeTextField: case TypeTextField:
col = T["string-text"] col = T["string-text"]
case TypeDateField: case TypeDateField:
@ -65,7 +66,7 @@ checkColumn:
case TypeIntegerField: case TypeIntegerField:
col = T["int32"] col = T["int32"]
case TypeBigIntegerField: case TypeBigIntegerField:
if al.Driver == DR_Sqlite { if al.Driver == DRSqlite {
fieldType = TypeIntegerField fieldType = TypeIntegerField
goto checkColumn goto checkColumn
} }
@ -89,6 +90,7 @@ checkColumn:
} }
case RelForeignKey, RelOneToOne: case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType fieldType = fi.relModelInfo.fields.pk.fieldType
fieldSize = fi.relModelInfo.fields.pk.size
goto checkColumn goto checkColumn
} }
@ -104,15 +106,15 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
typ += " " + "NOT NULL" typ += " " + "NOT NULL"
} }
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
Q, fi.mi.table, Q, Q, fi.mi.table, Q,
Q, fi.column, Q, Q, fi.column, Q,
typ, getColumnDefault(fi), typ, getColumnDefault(fi),
) )
} }
// create database creation string. // create database creation string.
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 { if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model") fmt.Println("no Model found, need register your model")
os.Exit(2) os.Exit(2)
@ -142,7 +144,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
if fi.auto { if fi.auto {
switch al.Driver { switch al.Driver {
case DR_Sqlite, DR_Postgres: case DRSqlite, DRPostgres:
column += T["auto"] column += T["auto"]
default: default:
column += col + " " + T["auto"] column += col + " " + T["auto"]
@ -159,7 +161,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
//if fi.initial.String() != "" { //if fi.initial.String() != "" {
// column += " DEFAULT " + fi.initial.String() // column += " DEFAULT " + fi.initial.String()
//} //}
// Append attribute DEFAULT // Append attribute DEFAULT
column += getColumnDefault(fi) column += getColumnDefault(fi)
@ -201,7 +203,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
sql += strings.Join(columns, ",\n") sql += strings.Join(columns, ",\n")
sql += "\n)" sql += "\n)"
if al.Driver == DR_MySQL { if al.Driver == DRMySQL {
var engine string var engine string
if mi.model != nil { if mi.model != nil {
engine = getTableEngine(mi.addrField) engine = getTableEngine(mi.addrField)
@ -237,7 +239,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
index := dbIndex{} index := dbIndex{}
index.Table = mi.table index.Table = mi.table
index.Name = name index.Name = name
index.Sql = sql index.SQL = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index) tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
} }
@ -247,7 +249,6 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
return return
} }
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
func getColumnDefault(fi *fieldInfo) string { func getColumnDefault(fi *fieldInfo) string {
var ( var (
@ -263,16 +264,20 @@ func getColumnDefault(fi *fieldInfo) string {
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on // These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType { switch fi.fieldType {
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField, TypeTextField:
return v; return v
case TypeBooleanField, TypeBitField, TypeSmallIntegerField, TypeIntegerField, case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField,
TypeDecimalField: TypeDecimalField:
d = "0" t = " DEFAULT %s "
d = "0"
case TypeBooleanField:
t = " DEFAULT %s "
d = "FALSE"
} }
if fi.colDefault { if fi.colDefault {
if !fi.initial.Exist() { if !fi.initial.Exist() {
v = fmt.Sprintf(t, "") v = fmt.Sprintf(t, "")

175
orm/db.go
View File

@ -24,12 +24,13 @@ import (
) )
const ( const (
format_Date = "2006-01-02" formatDate = "2006-01-02"
format_DateTime = "2006-01-02 15:04:05" formatDateTime = "2006-01-02 15:04:05"
) )
var ( var (
ErrMissPK = errors.New("missed pk value") // missing pk error // ErrMissPK missing pk error
ErrMissPK = errors.New("missed pk value")
) )
var ( var (
@ -216,14 +217,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} }
} }
if fi.null == false && value == nil { if fi.null == false && value == nil {
return nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName)) return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
} }
} }
} }
} }
switch fi.fieldType { switch fi.fieldType {
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert { if fi.autoNow || fi.autoNowAdd && insert {
if insert { if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() { if t, ok := value.(time.Time); ok && !t.IsZero() {
break break
@ -282,13 +283,12 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
var id int64 var id int64
err := row.Scan(&id) err := row.Scan(&id)
return id, err return id, err
} else {
if res, err := stmt.Exec(values...); err == nil {
return res.LastInsertId()
} else {
return 0, err
}
} }
res, err := stmt.Exec(values...)
if err == nil {
return res.LastInsertId()
}
return 0, err
} }
// query sql ,read records and persist in dbBaser. // query sql ,read records and persist in dbBaser.
@ -339,15 +339,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
return ErrNoRows return ErrNoRows
} }
return err return err
} else {
elm := reflect.New(mi.addrField.Elem().Type())
mind := reflect.Indirect(elm)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
ind.Set(mind)
} }
elm := reflect.New(mi.addrField.Elem().Type())
mind := reflect.Indirect(elm)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
ind.Set(mind)
return nil return nil
} }
@ -444,20 +440,19 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) { if isMulti || !d.ins.HasReturningID(mi, &query) {
if res, err := q.Exec(query, values...); err == nil { res, err := q.Exec(query, values...)
if err == nil {
if isMulti { if isMulti {
return res.RowsAffected() return res.RowsAffected()
} }
return res.LastInsertId() return res.LastInsertId()
} else {
return 0, err
} }
} else { return 0, err
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} }
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} }
// execute update sql dbQuerier with given struct reflect.Value. // execute update sql dbQuerier with given struct reflect.Value.
@ -493,11 +488,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
if res, err := q.Exec(query, setValues...); err == nil { res, err := q.Exec(query, setValues...)
if err == nil {
return res.RowsAffected() return res.RowsAffected()
} else {
return 0, err
} }
return 0, err
} }
// execute delete sql dbQuerier with given struct reflect.Value. // execute delete sql dbQuerier with given struct reflect.Value.
@ -513,14 +508,12 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q) query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, pkValue)
if res, err := q.Exec(query, pkValue); err == nil { if err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
if err != nil { if err != nil {
return 0, err return 0, err
} }
if num > 0 { if num > 0 {
if mi.fields.pk.auto { if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
@ -529,17 +522,14 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.Field(mi.fields.pk.fieldIndex).SetInt(0) ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
} }
} }
err := d.deleteRels(q, mi, []interface{}{pkValue}, tz) err := d.deleteRels(q, mi, []interface{}{pkValue}, tz)
if err != nil { if err != nil {
return num, err return num, err
} }
} }
return num, err return num, err
} else {
return 0, err
} }
return 0, err
} }
// update table-related record by querySet. // update table-related record by querySet.
@ -565,11 +555,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
} }
where, args := tables.getCondSql(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
values = append(values, args...) values = append(values, args...)
join := tables.getJoinSql() join := tables.getJoinSQL()
var query, T string var query, T string
@ -585,13 +575,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q)
if c, ok := values[i].(colValue); ok { if c, ok := values[i].(colValue); ok {
switch c.opt { switch c.opt {
case Col_Add: case ColAdd:
cols = append(cols, col+" = "+col+" + ?") cols = append(cols, col+" = "+col+" + ?")
case Col_Minus: case ColMinus:
cols = append(cols, col+" = "+col+" - ?") cols = append(cols, col+" = "+col+" - ?")
case Col_Multiply: case ColMultiply:
cols = append(cols, col+" = "+col+" * ?") cols = append(cols, col+" = "+col+" * ?")
case Col_Except: case ColExcept:
cols = append(cols, col+" = "+col+" / ?") cols = append(cols, col+" = "+col+" / ?")
} }
values[i] = c.value values[i] = c.value
@ -610,12 +600,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} }
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, values...)
if res, err := q.Exec(query, values...); err == nil { if err == nil {
return res.RowsAffected() return res.RowsAffected()
} else {
return 0, err
} }
return 0, err
} }
// delete related records. // delete related records.
@ -624,23 +613,23 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
for _, fi := range mi.fields.fieldsReverse { for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo fi = fi.reverseFieldInfo
switch fi.onDelete { switch fi.onDelete {
case od_CASCADE: case odCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
if err != nil { if err != nil {
return err return err
} }
case od_SET_DEFAULT, od_SET_NULL: case odSetDefault, odSetNULL:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
params := Params{fi.column: nil} params := Params{fi.column: nil}
if fi.onDelete == od_SET_DEFAULT { if fi.onDelete == odSetDefault {
params[fi.column] = fi.initial.String() params[fi.column] = fi.initial.String()
} }
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
if err != nil { if err != nil {
return err return err
} }
case od_DO_NOTHING: case odDoNothing:
} }
} }
return nil return nil
@ -661,8 +650,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
where, args := tables.getCondSql(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
join := tables.getJoinSql() join := tables.getJoinSQL()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where)
@ -670,16 +659,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil { r, err := q.Query(query, args...)
if err != nil {
return 0, err return 0, err
} else {
rs = r
} }
rs = r
defer rs.Close() defer rs.Close()
var ref interface{} var ref interface{}
args = make([]interface{}, 0) args = make([]interface{}, 0)
cnt := 0 cnt := 0
for rs.Next() { for rs.Next() {
@ -702,24 +689,21 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, args...)
if res, err := q.Exec(query, args...); err == nil { if err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
if err != nil { if err != nil {
return 0, err return 0, err
} }
if num > 0 { if num > 0 {
err := d.deleteRels(q, mi, args, tz) err := d.deleteRels(q, mi, args, tz)
if err != nil { if err != nil {
return num, err return num, err
} }
} }
return num, nil return num, nil
} else {
return 0, err
} }
return 0, err
} }
// read related records. // read related records.
@ -801,10 +785,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
orderBy := tables.getOrderSql(qs.orders) groupBy := tables.getGroupSQL(qs.groups)
limit := tables.getLimitSql(mi, offset, rlimit) orderBy := tables.getOrderSQL(qs.orders)
join := tables.getJoinSql() limit := tables.getLimitSQL(mi, offset, rlimit)
join := tables.getJoinSQL()
for _, tbl := range tables.tables { for _, tbl := range tables.tables {
if tbl.sel { if tbl.sel {
@ -814,16 +799,20 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
} }
} }
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit) sqlSelect := "SELECT"
if qs.distinct {
sqlSelect += " DISTINCT"
}
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil { r, err := q.Query(query, args...)
if err != nil {
return 0, err return 0, err
} else {
rs = r
} }
rs = r
refs := make([]interface{}, colsNum) refs := make([]interface{}, colsNum)
for i := range refs { for i := range refs {
@ -937,9 +926,9 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
tables := newDbTables(mi, d.ins) tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth) tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
tables.getOrderSql(qs.orders) tables.getOrderSQL(qs.orders)
join := tables.getJoinSql() join := tables.getJoinSQL()
Q := d.ins.TableQuote() Q := d.ins.TableQuote()
@ -954,7 +943,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
} }
// generate sql with replacing operator string placeholders and replaced values. // generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := "" sql := ""
params := getFlatParams(fi, args, tz) params := getFlatParams(fi, args, tz)
@ -979,7 +968,7 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
if len(params) > 1 { if len(params) > 1 {
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
} }
sql = d.ins.OperatorSql(operator) sql = d.ins.OperatorSQL(operator)
switch operator { switch operator {
case "exact": case "exact":
if arg == nil { if arg == nil {
@ -1107,12 +1096,12 @@ setValue:
) )
if len(s) >= 19 { if len(s) >= 19 {
s = s[:19] s = s[:19]
t, err = time.ParseInLocation(format_DateTime, s, tz) t, err = time.ParseInLocation(formatDateTime, s, tz)
} else { } else {
if len(s) > 10 { if len(s) > 10 {
s = s[:10] s = s[:10]
} }
t, err = time.ParseInLocation(format_Date, s, tz) t, err = time.ParseInLocation(formatDate, s, tz)
} }
t = t.In(DefaultTimeLoc) t = t.In(DefaultTimeLoc)
@ -1443,24 +1432,22 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
} }
} }
where, args := tables.getCondSql(cond, false, tz) where, args := tables.getCondSQL(cond, false, tz)
orderBy := tables.getOrderSql(qs.orders) groupBy := tables.getGroupSQL(qs.groups)
limit := tables.getLimitSql(mi, qs.offset, qs.limit) orderBy := tables.getOrderSQL(qs.orders)
join := tables.getJoinSql() limit := tables.getLimitSQL(mi, qs.offset, qs.limit)
join := tables.getJoinSQL()
sels := strings.Join(cols, ", ") sels := strings.Join(cols, ", ")
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit) query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows rs, err := q.Query(query, args...)
if r, err := q.Query(query, args...); err != nil { if err != nil {
return 0, err return 0, err
} else {
rs = r
} }
refs := make([]interface{}, len(cols)) refs := make([]interface{}, len(cols))
for i := range refs { for i := range refs {
var ref interface{} var ref interface{}
@ -1475,11 +1462,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
) )
for rs.Next() { for rs.Next() {
if cnt == 0 { if cnt == 0 {
if cols, err := rs.Columns(); err != nil { cols, err := rs.Columns()
if err != nil {
return 0, err return 0, err
} else {
columns = cols
} }
columns = cols
} }
if err := rs.Scan(refs...); err != nil { if err := rs.Scan(refs...); err != nil {
@ -1643,7 +1630,7 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
} }
// not implement. // not implement.
func (d *dbBase) OperatorSql(operator string) string { func (d *dbBase) OperatorSQL(operator string) string {
panic(ErrNotImplement) panic(ErrNotImplement)
} }

View File

@ -22,15 +22,17 @@ import (
"time" "time"
) )
// database driver constant int. // DriverType database driver constant int.
type DriverType int type DriverType int
// Enum the Database driver
const ( const (
_ DriverType = iota // int enum type _ DriverType = iota // int enum type
DR_MySQL // mysql DRMySQL // mysql
DR_Sqlite // sqlite DRSqlite // sqlite
DR_Oracle // oracle DROracle // oracle
DR_Postgres // pgsql DRPostgres // pgsql
DRTiDB // TiDB
) )
// database driver string. // database driver string.
@ -53,15 +55,17 @@ var _ Driver = new(driver)
var ( var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)} dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]DriverType{ drivers = map[string]DriverType{
"mysql": DR_MySQL, "mysql": DRMySQL,
"postgres": DR_Postgres, "postgres": DRPostgres,
"sqlite3": DR_Sqlite, "sqlite3": DRSqlite,
"tidb": DRTiDB,
} }
dbBasers = map[DriverType]dbBaser{ dbBasers = map[DriverType]dbBaser{
DR_MySQL: newdbBaseMysql(), DRMySQL: newdbBaseMysql(),
DR_Sqlite: newdbBaseSqlite(), DRSqlite: newdbBaseSqlite(),
DR_Oracle: newdbBaseMysql(), DROracle: newdbBaseOracle(),
DR_Postgres: newdbBasePostgres(), DRPostgres: newdbBasePostgres(),
DRTiDB: newdbBaseTidb(),
} }
) )
@ -119,7 +123,7 @@ func detectTZ(al *alias) {
} }
switch al.Driver { switch al.Driver {
case DR_MySQL: case DRMySQL:
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string var tz string
row.Scan(&tz) row.Scan(&tz)
@ -147,10 +151,10 @@ func detectTZ(al *alias) {
al.Engine = "INNODB" al.Engine = "INNODB"
} }
case DR_Sqlite: case DRSqlite:
al.TZ = time.UTC al.TZ = time.UTC
case DR_Postgres: case DRPostgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string var tz string
row.Scan(&tz) row.Scan(&tz)
@ -188,12 +192,13 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
return al, nil return al, nil
} }
// AddAliasWthDB add a aliasName for the drivename
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db) _, err := addAliasWthDB(aliasName, driverName, db)
return err return err
} }
// Setting the database connect params. Use the database driver self dataSource args. // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var ( var (
err error err error
@ -236,7 +241,7 @@ end:
return err return err
} }
// Register a database driver use specify driver name, this can be definition the driver is which database type. // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error { func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; ok == false { if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ drivers[driverName] = typ
@ -248,7 +253,7 @@ func RegisterDriver(driverName string, typ DriverType) error {
return nil return nil
} }
// Change the database default used timezone // SetDataBaseTZ Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) error { func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok { if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz al.TZ = tz
@ -258,14 +263,14 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
return nil return nil
} }
// Change the max idle conns for *sql.DB, use specify database alias name // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
func SetMaxIdleConns(aliasName string, maxIdleConns int) { func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns al.MaxIdleConns = maxIdleConns
al.DB.SetMaxIdleConns(maxIdleConns) al.DB.SetMaxIdleConns(maxIdleConns)
} }
// Change the max open conns for *sql.DB, use specify database alias name // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
func SetMaxOpenConns(aliasName string, maxOpenConns int) { func SetMaxOpenConns(aliasName string, maxOpenConns int) {
al := getDbAlias(aliasName) al := getDbAlias(aliasName)
al.MaxOpenConns = maxOpenConns al.MaxOpenConns = maxOpenConns
@ -275,7 +280,7 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
} }
} }
// Get *sql.DB from registered database by db alias name. // GetDB Get *sql.DB from registered database by db alias name.
// Use "default" as alias name if you not set. // Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) { func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string var name string
@ -284,9 +289,9 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} else { } else {
name = "default" name = "default"
} }
if al, ok := dataBaseCache.get(name); ok { al, ok := dataBaseCache.get(name)
if ok {
return al.DB, nil return al.DB, nil
} else {
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
} }
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
} }

View File

@ -67,7 +67,7 @@ type dbBaseMysql struct {
var _ dbBaser = new(dbBaseMysql) var _ dbBaser = new(dbBaseMysql)
// get mysql operator. // get mysql operator.
func (d *dbBaseMysql) OperatorSql(operator string) string { func (d *dbBaseMysql) OperatorSQL(operator string) string {
return mysqlOperators[operator] return mysqlOperators[operator]
} }

View File

@ -66,7 +66,7 @@ type dbBasePostgres struct {
var _ dbBaser = new(dbBasePostgres) var _ dbBaser = new(dbBasePostgres)
// get postgresql operator. // get postgresql operator.
func (d *dbBasePostgres) OperatorSql(operator string) string { func (d *dbBasePostgres) OperatorSQL(operator string) string {
return postgresOperators[operator] return postgresOperators[operator]
} }
@ -101,7 +101,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
num := 0 num := 0
for _, c := range q { for _, c := range q {
if c == '?' { if c == '?' {
num += 1 num++
} }
} }
if num == 0 { if num == 0 {
@ -114,7 +114,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
if c == '?' { if c == '?' {
data = append(data, '$') data = append(data, '$')
data = append(data, []byte(strconv.Itoa(num))...) data = append(data, []byte(strconv.Itoa(num))...)
num += 1 num++
} else { } else {
data = append(data, c) data = append(data, c)
} }

View File

@ -66,7 +66,7 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite) var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator. // get sqlite operator.
func (d *dbBaseSqlite) OperatorSql(operator string) string { func (d *dbBaseSqlite) OperatorSQL(operator string) string {
return sqliteOperators[operator] return sqliteOperators[operator]
} }

View File

@ -164,7 +164,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
} }
// generate join string. // generate join string.
func (t *dbTables) getJoinSql() (join string) { func (t *dbTables) getJoinSQL() (join string) {
Q := t.base.TableQuote() Q := t.base.TableQuote()
for _, jt := range t.tables { for _, jt := range t.tables {
@ -186,7 +186,7 @@ func (t *dbTables) getJoinSql() (join string) {
table = jt.mi.table table = jt.mi.table
switch { switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk.column c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel { for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo { if jt.fi.mi == ffi.relModelInfo {
@ -220,7 +220,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
) )
num := len(exprs) - 1 num := len(exprs) - 1
names := make([]string, 0) var names []string
inner := true inner := true
@ -326,7 +326,7 @@ loopFor:
} }
// generate condition sql. // generate condition sql.
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() { if cond == nil || cond.IsEmpty() {
return return
} }
@ -347,7 +347,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
where += "NOT " where += "NOT "
} }
if p.isCond { if p.isCond {
w, ps := t.getCondSql(p.cond, true, tz) w, ps := t.getCondSQL(p.cond, true, tz)
if w != "" { if w != "" {
w = fmt.Sprintf("( %s) ", w) w = fmt.Sprintf("( %s) ", w)
} }
@ -372,12 +372,12 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
operator = "exact" operator = "exact"
} }
operSql, args := t.base.GenerateOperatorSql(mi, fi, operator, p.args, tz) operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSql) where += fmt.Sprintf("%s %s ", leftCol, operSQL)
params = append(params, args...) params = append(params, args...)
} }
@ -390,8 +390,32 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
return return
} }
// generate group sql.
func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
if len(groups) == 0 {
return
}
Q := t.base.TableQuote()
groupSqls := make([]string, 0, len(groups))
for _, group := range groups {
exprs := strings.Split(group, ExprSep)
index, _, fi, suc := t.parseExprs(t.mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
}
groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
return
}
// generate order sql. // generate order sql.
func (t *dbTables) getOrderSql(orders []string) (orderSql string) { func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
if len(orders) == 0 { if len(orders) == 0 {
return return
} }
@ -415,12 +439,12 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
} }
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return return
} }
// generate limit sql. // generate limit sql.
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) { func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 { if limit == 0 {
limit = int64(DefaultRowsLimit) limit = int64(DefaultRowsLimit)
} }

63
orm/db_tidb.go Normal file
View File

@ -0,0 +1,63 @@
// Copyright 2015 TiDB Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
)
// mysql dbBaser implementation.
type dbBaseTidb struct {
dbBase
}
var _ dbBaser = new(dbBaseTidb)
// get mysql operator.
func (d *dbBaseTidb) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}
// get mysql table field types.
func (d *dbBaseTidb) DbTypes() map[string]string {
return mysqlTypes
}
// show table sql for mysql.
func (d *dbBaseTidb) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
// show columns sql of table for mysql.
func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
// execute sql to check index exist.
func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
// create new mysql dbBaser.
func newdbBaseTidb() dbBaser {
b := new(dbBaseTidb)
b.ins = b
return b
}

View File

@ -24,9 +24,8 @@ import (
func getDbAlias(name string) *alias { func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok { if al, ok := dataBaseCache.get(name); ok {
return al return al
} else {
panic(fmt.Errorf("unknown DataBase alias name %s", name))
} }
panic(fmt.Errorf("unknown DataBase alias name %s", name))
} }
// get pk column info. // get pk column info.
@ -80,19 +79,19 @@ outFor:
var err error var err error
if len(v) >= 19 { if len(v) >= 19 {
s := v[:19] s := v[:19]
t, err = time.ParseInLocation(format_DateTime, s, DefaultTimeLoc) t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
} else { } else {
s := v s := v
if len(v) > 10 { if len(v) > 10 {
s = v[:10] s = v[:10]
} }
t, err = time.ParseInLocation(format_Date, s, tz) t, err = time.ParseInLocation(formatDate, s, tz)
} }
if err == nil { if err == nil {
if fi.fieldType == TypeDateField { if fi.fieldType == TypeDateField {
v = t.In(tz).Format(format_Date) v = t.In(tz).Format(formatDate)
} else { } else {
v = t.In(tz).Format(format_DateTime) v = t.In(tz).Format(formatDateTime)
} }
} }
} }
@ -137,9 +136,9 @@ outFor:
case reflect.Struct: case reflect.Struct:
if v, ok := arg.(time.Time); ok { if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField { if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(format_Date) arg = v.In(tz).Format(formatDate)
} else { } else {
arg = v.In(tz).Format(format_DateTime) arg = v.In(tz).Format(formatDateTime)
} }
} else { } else {
typ := val.Type() typ := val.Type()

View File

@ -19,10 +19,10 @@ import (
) )
const ( const (
od_CASCADE = "cascade" odCascade = "cascade"
od_SET_NULL = "set_null" odSetNULL = "set_null"
od_SET_DEFAULT = "set_default" odSetDefault = "set_default"
od_DO_NOTHING = "do_nothing" odDoNothing = "do_nothing"
defaultStructTagName = "orm" defaultStructTagName = "orm"
defaultStructTagDelim = ";" defaultStructTagDelim = ";"
) )
@ -113,7 +113,7 @@ func (mc *_modelCache) clean() {
mc.done = false mc.done = false
} }
// Clean model cache. Then you can re-RegisterModel. // ResetModelCache Clean model cache. Then you can re-RegisterModel.
// Common use this api for test case. // Common use this api for test case.
func ResetModelCache() { func ResetModelCache() {
modelCache.clean() modelCache.clean()

View File

@ -51,19 +51,16 @@ func registerModel(prefix string, model interface{}) {
} }
info := newModelInfo(val) info := newModelInfo(val)
if info.fields.pk == nil { if info.fields.pk == nil {
outFor: outFor:
for _, fi := range info.fields.fieldsDB { for _, fi := range info.fields.fieldsDB {
if fi.name == "Id" { if strings.ToLower(fi.name) == "id" {
if fi.sf.Tag.Get(defaultStructTagName) == "" { switch fi.addrValue.Elem().Kind() {
switch fi.addrValue.Elem().Kind() { case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: fi.auto = true
fi.auto = true fi.pk = true
fi.pk = true info.fields.pk = fi
info.fields.pk = fi break outFor
break outFor
}
} }
} }
} }
@ -269,7 +266,10 @@ func bootStrap() {
if found == false { if found == false {
mForC: mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
if ffi.relModelInfo == mi { conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
fi.relTable != "" && fi.relTable == ffi.relTable ||
fi.relThrough == "" && fi.relTable == ""
if ffi.relModelInfo == mi && conditions {
found = true found = true
fi.reverseField = ffi.reverseFieldInfoTwo.name fi.reverseField = ffi.reverseFieldInfoTwo.name
@ -298,12 +298,12 @@ end:
} }
} }
// register models // RegisterModel register models
func RegisterModel(models ...interface{}) { func RegisterModel(models ...interface{}) {
RegisterModelWithPrefix("", models...) RegisterModelWithPrefix("", models...)
} }
// register models with a prefix // RegisterModelWithPrefix register models with a prefix
func RegisterModelWithPrefix(prefix string, models ...interface{}) { func RegisterModelWithPrefix(prefix string, models ...interface{}) {
if modelCache.done { if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap")) panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
@ -314,7 +314,7 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
} }
} }
// bootrap models. // BootStrap bootrap models.
// make all model parsed and can not add more models // make all model parsed and can not add more models
func BootStrap() { func BootStrap() {
if modelCache.done { if modelCache.done {

View File

@ -15,49 +15,28 @@
package orm package orm
import ( import (
"errors"
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
) )
// Define the Type enum
const ( const (
// bool
TypeBooleanField = 1 << iota TypeBooleanField = 1 << iota
// string
TypeCharField TypeCharField
// string
TypeTextField TypeTextField
// time.Time
TypeDateField TypeDateField
// time.Time
TypeDateTimeField TypeDateTimeField
// int8
TypeBitField TypeBitField
// int16
TypeSmallIntegerField TypeSmallIntegerField
// int32
TypeIntegerField TypeIntegerField
// int64
TypeBigIntegerField TypeBigIntegerField
// uint8
TypePositiveBitField TypePositiveBitField
// uint16
TypePositiveSmallIntegerField TypePositiveSmallIntegerField
// uint32
TypePositiveIntegerField TypePositiveIntegerField
// uint64
TypePositiveBigIntegerField TypePositiveBigIntegerField
// float64
TypeFloatField TypeFloatField
// float64
TypeDecimalField TypeDecimalField
RelForeignKey RelForeignKey
RelOneToOne RelOneToOne
RelManyToMany RelManyToMany
@ -65,6 +44,7 @@ const (
RelReverseMany RelReverseMany
) )
// Define some logic enum
const ( const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5 IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9 IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9
@ -72,25 +52,30 @@ const (
IsFieldType = ^-RelReverseMany<<1 + 1 IsFieldType = ^-RelReverseMany<<1 + 1
) )
// A true/false field. // BooleanField A true/false field.
type BooleanField bool type BooleanField bool
// Value return the BooleanField
func (e BooleanField) Value() bool { func (e BooleanField) Value() bool {
return bool(e) return bool(e)
} }
// Set will set the BooleanField
func (e *BooleanField) Set(d bool) { func (e *BooleanField) Set(d bool) {
*e = BooleanField(d) *e = BooleanField(d)
} }
// String format the Bool to string
func (e *BooleanField) String() string { func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value()) return strconv.FormatBool(e.Value())
} }
// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int { func (e *BooleanField) FieldType() int {
return TypeBooleanField return TypeBooleanField
} }
// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error { func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case bool: case bool:
@ -102,56 +87,65 @@ func (e *BooleanField) SetRaw(value interface{}) error {
} }
return err return err
default: default:
return errors.New(fmt.Sprintf("<BooleanField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return the current value
func (e *BooleanField) RawValue() interface{} { func (e *BooleanField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField) var _ Fielder = new(BooleanField)
// A string field // CharField A string field
// required values tag: size // required values tag: size
// The size is enforced at the database level and in modelss validation. // The size is enforced at the database level and in modelss validation.
// eg: `orm:"size(120)"` // eg: `orm:"size(120)"`
type CharField string type CharField string
// Value return the CharField's Value
func (e CharField) Value() string { func (e CharField) Value() string {
return string(e) return string(e)
} }
// Set CharField value
func (e *CharField) Set(d string) { func (e *CharField) Set(d string) {
*e = CharField(d) *e = CharField(d)
} }
// String return the CharField
func (e *CharField) String() string { func (e *CharField) String() string {
return e.Value() return e.Value()
} }
// FieldType return the enum type
func (e *CharField) FieldType() int { func (e *CharField) FieldType() int {
return TypeCharField return TypeCharField
} }
// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error { func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case string: case string:
e.Set(d) e.Set(d)
default: default:
return errors.New(fmt.Sprintf("<CharField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return the CharField value
func (e *CharField) RawValue() interface{} { func (e *CharField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify CharField implement Fielder
var _ Fielder = new(CharField) var _ Fielder = new(CharField)
// A date, represented in go by a time.Time instance. // DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02 // only date values like 2006-01-02
// Has a few extra, optional attr tag: // Has a few extra, optional attr tag:
// //
@ -166,106 +160,125 @@ var _ Fielder = new(CharField)
// eg: `orm:"auto_now"` or `orm:"auto_now_add"` // eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time type DateField time.Time
// Value return the time.Time
func (e DateField) Value() time.Time { func (e DateField) Value() time.Time {
return time.Time(e) return time.Time(e)
} }
// Set set the DateField's value
func (e *DateField) Set(d time.Time) { func (e *DateField) Set(d time.Time) {
*e = DateField(d) *e = DateField(d)
} }
// String convert datatime to string
func (e *DateField) String() string { func (e *DateField) String() string {
return e.Value().String() return e.Value().String()
} }
// FieldType return enum type Date
func (e *DateField) FieldType() int { func (e *DateField) FieldType() int {
return TypeDateField return TypeDateField
} }
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error { func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case time.Time: case time.Time:
e.Set(d) e.Set(d)
case string: case string:
v, err := timeParse(d, format_Date) v, err := timeParse(d, formatDate)
if err != nil { if err != nil {
e.Set(v) e.Set(v)
} }
return err return err
default: default:
return errors.New(fmt.Sprintf("<DateField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return Date value
func (e *DateField) RawValue() interface{} { func (e *DateField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify DateField implement fielder interface
var _ Fielder = new(DateField) var _ Fielder = new(DateField)
// A date, represented in go by a time.Time instance. // DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05 // datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField. // Takes the same extra arguments as DateField.
type DateTimeField time.Time type DateTimeField time.Time
// Value return the datatime value
func (e DateTimeField) Value() time.Time { func (e DateTimeField) Value() time.Time {
return time.Time(e) return time.Time(e)
} }
// Set set the time.Time to datatime
func (e *DateTimeField) Set(d time.Time) { func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d) *e = DateTimeField(d)
} }
// String return the time's String
func (e *DateTimeField) String() string { func (e *DateTimeField) String() string {
return e.Value().String() return e.Value().String()
} }
// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int { func (e *DateTimeField) FieldType() int {
return TypeDateTimeField return TypeDateTimeField
} }
// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error { func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case time.Time: case time.Time:
e.Set(d) e.Set(d)
case string: case string:
v, err := timeParse(d, format_DateTime) v, err := timeParse(d, formatDateTime)
if err != nil { if err != nil {
e.Set(v) e.Set(v)
} }
return err return err
default: default:
return errors.New(fmt.Sprintf("<DateTimeField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return the datatime value
func (e *DateTimeField) RawValue() interface{} { func (e *DateTimeField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify datatime implement fielder
var _ Fielder = new(DateTimeField) var _ Fielder = new(DateTimeField)
// A floating-point number represented in go by a float32 value. // FloatField A floating-point number represented in go by a float32 value.
type FloatField float64 type FloatField float64
// Value return the FloatField value
func (e FloatField) Value() float64 { func (e FloatField) Value() float64 {
return float64(e) return float64(e)
} }
// Set the Float64
func (e *FloatField) Set(d float64) { func (e *FloatField) Set(d float64) {
*e = FloatField(d) *e = FloatField(d)
} }
// String return the string
func (e *FloatField) String() string { func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32) return ToStr(e.Value(), -1, 32)
} }
// FieldType return the enum type
func (e *FloatField) FieldType() int { func (e *FloatField) FieldType() int {
return TypeFloatField return TypeFloatField
} }
// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error { func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case float32: case float32:
@ -278,36 +291,43 @@ func (e *FloatField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<FloatField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} { func (e *FloatField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify FloatField implement Fielder
var _ Fielder = new(FloatField) var _ Fielder = new(FloatField)
// -32768 to 32767 // SmallIntegerField -32768 to 32767
type SmallIntegerField int16 type SmallIntegerField int16
// Value return int16 value
func (e SmallIntegerField) Value() int16 { func (e SmallIntegerField) Value() int16 {
return int16(e) return int16(e)
} }
// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) { func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d) *e = SmallIntegerField(d)
} }
// String convert smallint to string
func (e *SmallIntegerField) String() string { func (e *SmallIntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int { func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField return TypeSmallIntegerField
} }
// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error { func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case int16: case int16:
@ -318,36 +338,43 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<SmallIntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} { func (e *SmallIntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField) var _ Fielder = new(SmallIntegerField)
// -2147483648 to 2147483647 // IntegerField -2147483648 to 2147483647
type IntegerField int32 type IntegerField int32
// Value return the int32
func (e IntegerField) Value() int32 { func (e IntegerField) Value() int32 {
return int32(e) return int32(e)
} }
// Set IntegerField value
func (e *IntegerField) Set(d int32) { func (e *IntegerField) Set(d int32) {
*e = IntegerField(d) *e = IntegerField(d)
} }
// String convert Int32 to string
func (e *IntegerField) String() string { func (e *IntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return the enum type
func (e *IntegerField) FieldType() int { func (e *IntegerField) FieldType() int {
return TypeIntegerField return TypeIntegerField
} }
// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error { func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case int32: case int32:
@ -358,36 +385,43 @@ func (e *IntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<IntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} { func (e *IntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify IntegerField implement Fielder
var _ Fielder = new(IntegerField) var _ Fielder = new(IntegerField)
// -9223372036854775808 to 9223372036854775807. // BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64 type BigIntegerField int64
// Value return int64
func (e BigIntegerField) Value() int64 { func (e BigIntegerField) Value() int64 {
return int64(e) return int64(e)
} }
// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) { func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d) *e = BigIntegerField(d)
} }
// String convert BigIntegerField to string
func (e *BigIntegerField) String() string { func (e *BigIntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return enum type
func (e *BigIntegerField) FieldType() int { func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField return TypeBigIntegerField
} }
// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error { func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case int64: case int64:
@ -398,36 +432,43 @@ func (e *BigIntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<BigIntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} { func (e *BigIntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField) var _ Fielder = new(BigIntegerField)
// 0 to 65535 // PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16 type PositiveSmallIntegerField uint16
// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 { func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e) return uint16(e)
} }
// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) { func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d) *e = PositiveSmallIntegerField(d)
} }
// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string { func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int { func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField return TypePositiveSmallIntegerField
} }
// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case uint16: case uint16:
@ -438,36 +479,43 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} { func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField) var _ Fielder = new(PositiveSmallIntegerField)
// 0 to 4294967295 // PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32 type PositiveIntegerField uint32
// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 { func (e PositiveIntegerField) Value() uint32 {
return uint32(e) return uint32(e)
} }
// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) { func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d) *e = PositiveIntegerField(d)
} }
// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string { func (e *PositiveIntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int { func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField return TypePositiveIntegerField
} }
// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error { func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case uint32: case uint32:
@ -478,36 +526,43 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} { func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField) var _ Fielder = new(PositiveIntegerField)
// 0 to 18446744073709551615 // PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64 type PositiveBigIntegerField uint64
// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 { func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e) return uint64(e)
} }
// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) { func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d) *e = PositiveBigIntegerField(d)
} }
// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string { func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value()) return ToStr(e.Value())
} }
// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int { func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField return TypePositiveIntegerField
} }
// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case uint64: case uint64:
@ -518,48 +573,57 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
e.Set(v) e.Set(v)
} }
default: default:
return errors.New(fmt.Sprintf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} { func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField) var _ Fielder = new(PositiveBigIntegerField)
// A large text field. // TextField A large text field.
type TextField string type TextField string
// Value return TextField value
func (e TextField) Value() string { func (e TextField) Value() string {
return string(e) return string(e)
} }
// Set the TextField value
func (e *TextField) Set(d string) { func (e *TextField) Set(d string) {
*e = TextField(d) *e = TextField(d)
} }
// String convert TextField to string
func (e *TextField) String() string { func (e *TextField) String() string {
return e.Value() return e.Value()
} }
// FieldType return enum type
func (e *TextField) FieldType() int { func (e *TextField) FieldType() int {
return TypeTextField return TypeTextField
} }
// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error { func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case string: case string:
e.Set(d) e.Set(d)
default: default:
return errors.New(fmt.Sprintf("<TextField.SetRaw> unknown value `%s`", value)) return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value)
} }
return nil return nil
} }
// RawValue return TextField value
func (e *TextField) RawValue() interface{} { func (e *TextField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify TextField implement Fielder
var _ Fielder = new(TextField) var _ Fielder = new(TextField)

View File

@ -119,8 +119,8 @@ type fieldInfo struct {
colDefault bool colDefault bool
initial StrTo initial StrTo
size int size int
auto_now bool autoNow bool
auto_now_add bool autoNowAdd bool
rel bool rel bool
reverse bool reverse bool
reverseField string reverseField string
@ -223,6 +223,11 @@ checkType:
break checkType break checkType
case "many": case "many":
fieldType = RelReverseMany fieldType = RelReverseMany
if tv := tags["rel_table"]; tv != "" {
fi.relTable = tv
} else if tv := tags["rel_through"]; tv != "" {
fi.relThrough = tv
}
break checkType break checkType
default: default:
err = fmt.Errorf("error") err = fmt.Errorf("error")
@ -309,20 +314,20 @@ checkType:
if fi.rel && fi.dbcol { if fi.rel && fi.dbcol {
switch onDelete { switch onDelete {
case od_CASCADE, od_DO_NOTHING: case odCascade, odDoNothing:
case od_SET_DEFAULT: case odSetDefault:
if initial.Exist() == false { if initial.Exist() == false {
err = errors.New("on_delete: set_default need set field a default value") err = errors.New("on_delete: set_default need set field a default value")
goto end goto end
} }
case od_SET_NULL: case odSetNULL:
if fi.null == false { if fi.null == false {
err = errors.New("on_delete: set_null need set field null") err = errors.New("on_delete: set_null need set field null")
goto end goto end
} }
default: default:
if onDelete == "" { if onDelete == "" {
onDelete = od_CASCADE onDelete = odCascade
} else { } else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end goto end
@ -350,9 +355,9 @@ checkType:
fi.unique = false fi.unique = false
case TypeDateField, TypeDateTimeField: case TypeDateField, TypeDateTimeField:
if attrs["auto_now"] { if attrs["auto_now"] {
fi.auto_now = true fi.autoNow = true
} else if attrs["auto_now_add"] { } else if attrs["auto_now_add"] {
fi.auto_now_add = true fi.autoNowAdd = true
} }
case TypeFloatField: case TypeFloatField:
case TypeDecimalField: case TypeDecimalField:

View File

@ -15,7 +15,6 @@
package orm package orm
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
@ -72,13 +71,13 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
added := info.fields.Add(fi) added := info.fields.Add(fi)
if added == false { if added == false {
err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column)) err = fmt.Errorf("duplicate column name: %s", fi.column)
break break
} }
if fi.pk { if fi.pk {
if info.fields.pk != nil { if info.fields.pk != nil {
err = errors.New(fmt.Sprintf("one model must have one pk field only")) err = fmt.Errorf("one model must have one pk field only")
break break
} else { } else {
info.fields.pk = fi info.fields.pk = fi

View File

@ -25,6 +25,9 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
// As tidb can't use go get, so disable the tidb testing now
// _ "github.com/pingcap/tidb"
) )
// A slice string field. // A slice string field.
@ -76,21 +79,21 @@ func (e *SliceStringField) RawValue() interface{} {
var _ Fielder = new(SliceStringField) var _ Fielder = new(SliceStringField)
// A json field. // A json field.
type JsonField struct { type JSONField struct {
Name string Name string
Data string Data string
} }
func (e *JsonField) String() string { func (e *JSONField) String() string {
data, _ := json.Marshal(e) data, _ := json.Marshal(e)
return string(data) return string(data)
} }
func (e *JsonField) FieldType() int { func (e *JSONField) FieldType() int {
return TypeTextField return TypeTextField
} }
func (e *JsonField) SetRaw(value interface{}) error { func (e *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) { switch d := value.(type) {
case string: case string:
return json.Unmarshal([]byte(d), e) return json.Unmarshal([]byte(d), e)
@ -99,14 +102,14 @@ func (e *JsonField) SetRaw(value interface{}) error {
} }
} }
func (e *JsonField) RawValue() interface{} { func (e *JSONField) RawValue() interface{} {
return e.String() return e.String()
} }
var _ Fielder = new(JsonField) var _ Fielder = new(JSONField)
type Data struct { type Data struct {
Id int ID int `orm:"column(id)"`
Boolean bool Boolean bool
Char string `orm:"size(50)"` Char string `orm:"size(50)"`
Text string `orm:"type(text)"` Text string `orm:"type(text)"`
@ -130,7 +133,7 @@ type Data struct {
} }
type DataNull struct { type DataNull struct {
Id int ID int `orm:"column(id)"`
Boolean bool `orm:"null"` Boolean bool `orm:"null"`
Char string `orm:"null;size(50)"` Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"` Text string `orm:"null;type(text)"`
@ -193,7 +196,7 @@ type Float32 float64
type Float64 float64 type Float64 float64
type DataCustom struct { type DataCustom struct {
Id int ID int `orm:"column(id)"`
Boolean Boolean Boolean Boolean
Char string `orm:"size(50)"` Char string `orm:"size(50)"`
Text string `orm:"type(text)"` Text string `orm:"type(text)"`
@ -216,28 +219,28 @@ type DataCustom struct {
// only for mysql // only for mysql
type UserBig struct { type UserBig struct {
Id uint64 ID uint64 `orm:"column(id)"`
Name string Name string
} }
type User struct { type User struct {
Id int ID int `orm:"column(id)"`
UserName string `orm:"size(30);unique"` UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"` Email string `orm:"size(100)"`
Password string `orm:"size(100)"` Password string `orm:"size(100)"`
Status int16 `orm:"column(Status)"` Status int16 `orm:"column(Status)"`
IsStaff bool IsStaff bool
IsActive bool `orm:"default(1)"` IsActive bool `orm:"default(true)"`
Created time.Time `orm:"auto_now_add;type(date)"` Created time.Time `orm:"auto_now_add;type(date)"`
Updated time.Time `orm:"auto_now"` Updated time.Time `orm:"auto_now"`
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
Posts []*Post `orm:"reverse(many)" json:"-"` Posts []*Post `orm:"reverse(many)" json:"-"`
ShouldSkip string `orm:"-"` ShouldSkip string `orm:"-"`
Nums int Nums int
Langs SliceStringField `orm:"size(100)"` Langs SliceStringField `orm:"size(100)"`
Extra JsonField `orm:"type(text)"` Extra JSONField `orm:"type(text)"`
unexport bool `orm:"-"` unexport bool `orm:"-"`
unexport_ bool unexportBool bool
} }
func (u *User) TableIndex() [][]string { func (u *User) TableIndex() [][]string {
@ -259,7 +262,7 @@ func NewUser() *User {
} }
type Profile struct { type Profile struct {
Id int ID int `orm:"column(id)"`
Age int16 Age int16
Money float64 Money float64
User *User `orm:"reverse(one)" json:"-"` User *User `orm:"reverse(one)" json:"-"`
@ -276,7 +279,7 @@ func NewProfile() *Profile {
} }
type Post struct { type Post struct {
Id int ID int `orm:"column(id)"`
User *User `orm:"rel(fk)"` User *User `orm:"rel(fk)"`
Title string `orm:"size(60)"` Title string `orm:"size(60)"`
Content string `orm:"type(text)"` Content string `orm:"type(text)"`
@ -297,7 +300,7 @@ func NewPost() *Post {
} }
type Tag struct { type Tag struct {
Id int ID int `orm:"column(id)"`
Name string `orm:"size(30)"` Name string `orm:"size(30)"`
BestPost *Post `orm:"rel(one);null"` BestPost *Post `orm:"rel(one);null"`
Posts []*Post `orm:"reverse(many)" json:"-"` Posts []*Post `orm:"reverse(many)" json:"-"`
@ -309,7 +312,7 @@ func NewTag() *Tag {
} }
type PostTags struct { type PostTags struct {
Id int ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk)"` Post *Post `orm:"rel(fk)"`
Tag *Tag `orm:"rel(fk)"` Tag *Tag `orm:"rel(fk)"`
} }
@ -319,7 +322,7 @@ func (m *PostTags) TableName() string {
} }
type Comment struct { type Comment struct {
Id int ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk);column(post)"` Post *Post `orm:"rel(fk);column(post)"`
Content string `orm:"type(text)"` Content string `orm:"type(text)"`
Parent *Comment `orm:"null;rel(fk)"` Parent *Comment `orm:"null;rel(fk)"`
@ -331,6 +334,24 @@ func NewComment() *Comment {
return obj return obj
} }
type Group struct {
ID int `orm:"column(gid);size(32)"`
Name string
Permissions []*Permission `orm:"reverse(many)" json:"-"`
}
type Permission struct {
ID int `orm:"column(id)"`
Name string
Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"`
}
type GroupPermissions struct {
ID int `orm:"column(id)"`
Group *Group `orm:"rel(fk)"`
Permission *Permission `orm:"rel(fk)"`
}
var DBARGS = struct { var DBARGS = struct {
Driver string Driver string
Source string Source string
@ -345,6 +366,7 @@ var (
IsMysql = DBARGS.Driver == "mysql" IsMysql = DBARGS.Driver == "mysql"
IsSqlite = DBARGS.Driver == "sqlite3" IsSqlite = DBARGS.Driver == "sqlite3"
IsPostgres = DBARGS.Driver == "postgres" IsPostgres = DBARGS.Driver == "postgres"
IsTidb = DBARGS.Driver == "tidb"
) )
var ( var (
@ -364,6 +386,7 @@ Default DB Drivers.
mysql: https://github.com/go-sql-driver/mysql mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3 sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq postgres: https://github.com/lib/pq
tidb: https://github.com/pingcap/tidb
usage: usage:
@ -371,6 +394,7 @@ go get -u github.com/astaxie/beego/orm
go get -u github.com/go-sql-driver/mysql go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3 go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq go get -u github.com/lib/pq
go get -u github.com/pingcap/tidb
#### MySQL #### MySQL
mysql -u root -e 'create database orm_test;' mysql -u root -e 'create database orm_test;'
@ -390,6 +414,12 @@ psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/astaxie/beego/orm go test -v github.com/astaxie/beego/orm
#### TiDB
export ORM_DRIVER=tidb
export ORM_SOURCE='memory://test/test'
go test -v github.com/astaxie/beego/orm
`) `)
os.Exit(2) os.Exit(2)
} }
@ -397,7 +427,7 @@ go test -v github.com/astaxie/beego/orm
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
alias := getDbAlias("default") alias := getDbAlias("default")
if alias.Driver == DR_MySQL { if alias.Driver == DRMySQL {
alias.Engine = "INNODB" alias.Engine = "INNODB"
} }

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage // Simple Usage
// //
// package main // package main
@ -59,12 +60,13 @@ import (
"time" "time"
) )
// DebugQueries define the debug
const ( const (
Debug_Queries = iota DebugQueries = iota
) )
// Define common vars
var ( var (
// DebugLevel = Debug_Queries
Debug = false Debug = false
DebugLog = NewLog(os.Stderr) DebugLog = NewLog(os.Stderr)
DefaultRowsLimit = 1000 DefaultRowsLimit = 1000
@ -79,7 +81,10 @@ var (
ErrNotImplement = errors.New("have not implement") ErrNotImplement = errors.New("have not implement")
) )
// Params stores the Params
type Params map[string]interface{} type Params map[string]interface{}
// ParamsList stores paramslist
type ParamsList []interface{} type ParamsList []interface{}
type orm struct { type orm struct {
@ -188,7 +193,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
o.setPk(mi, ind, id) o.setPk(mi, ind, id)
cnt += 1 cnt++
} }
} else { } else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false) mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
@ -489,7 +494,7 @@ func (o *orm) Driver() Driver {
return driver(o.alias.Name) return driver(o.alias.Name)
} }
// create new orm // NewOrm create new orm
func NewOrm() Ormer { func NewOrm() Ormer {
BootStrap() // execute only once BootStrap() // execute only once
@ -501,7 +506,7 @@ func NewOrm() Ormer {
return o return o
} }
// create a new ormer object with specify *sql.DB for query // NewOrmWithDB create a new ormer object with specify *sql.DB for query
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
var al *alias var al *alias

View File

@ -19,6 +19,7 @@ import (
"strings" "strings"
) )
// ExprSep define the expression seperation
const ( const (
ExprSep = "__" ExprSep = "__"
) )
@ -32,19 +33,19 @@ type condValue struct {
isCond bool isCond bool
} }
// condition struct. // Condition struct.
// work for WHERE conditions. // work for WHERE conditions.
type Condition struct { type Condition struct {
params []condValue params []condValue
} }
// return new condition struct // NewCondition return new condition struct
func NewCondition() *Condition { func NewCondition() *Condition {
c := &Condition{} c := &Condition{}
return c return c
} }
// add expression to condition // And add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition { func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.And> args cannot empty")) panic(fmt.Errorf("<Condition.And> args cannot empty"))
@ -53,7 +54,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// add NOT expression to condition // AndNot add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition { func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.AndNot> args cannot empty")) panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
@ -62,7 +63,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// combine a condition to current condition // AndCond combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition { func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone() c = c.clone()
if c == cond { if c == cond {
@ -74,7 +75,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c return c
} }
// add OR expression to condition // Or add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition { func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.Or> args cannot empty")) panic(fmt.Errorf("<Condition.Or> args cannot empty"))
@ -83,7 +84,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// add OR NOT expression to condition // OrNot add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition { func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.OrNot> args cannot empty")) panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
@ -92,7 +93,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c return &c
} }
// combine a OR condition to current condition // OrCond combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition { func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone() c = c.clone()
if c == cond { if c == cond {
@ -104,12 +105,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c return c
} }
// check the condition arguments are empty or not. // IsEmpty check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool { func (c *Condition) IsEmpty() bool {
return len(c.params) == 0 return len(c.params) == 0
} }
// clone a condition // clone clone a condition
func (c Condition) clone() *Condition { func (c Condition) clone() *Condition {
return &c return &c
} }

View File

@ -23,11 +23,12 @@ import (
"time" "time"
) )
// Log implement the log.Logger
type Log struct { type Log struct {
*log.Logger *log.Logger
} }
// set io.Writer to create a Logger. // NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log { func NewLog(out io.Writer) *Log {
d := new(Log) d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9) d.Logger = log.New(out, "[ORM]", 1e9)
@ -41,7 +42,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if err != nil { if err != nil {
flag = "FAIL" flag = "FAIL"
} }
con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query) con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(formatDateTime), alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args)) cons := make([]string, 0, len(args))
for _, arg := range args { for _, arg := range args {
cons = append(cons, fmt.Sprintf("%v", arg)) cons = append(cons, fmt.Sprintf("%v", arg))

View File

@ -14,9 +14,7 @@
package orm package orm
import ( import "reflect"
"reflect"
)
// model to model struct // model to model struct
type queryM2M struct { type queryM2M struct {
@ -44,7 +42,21 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
dbase := orm.alias.DbBaser dbase := orm.alias.DbBaser
var models []interface{} var models []interface{}
var other_values []interface{}
var other_names []string
for _, colname := range mi.fields.dbcols {
if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column &&
mi.fields.columns[colname] != mi.fields.pk {
other_names = append(other_names, colname)
}
}
for i, md := range mds {
if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 {
other_values = append(other_values, md)
mds = append(mds[:i], mds[i+1:]...)
}
}
for _, md := range mds { for _, md := range mds {
val := reflect.ValueOf(md) val := reflect.ValueOf(md)
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
@ -67,11 +79,9 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
names := []string{mfi.column, rfi.column} names := []string{mfi.column, rfi.column}
values := make([]interface{}, 0, len(models)*2) values := make([]interface{}, 0, len(models)*2)
for _, md := range models { for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md)) ind := reflect.Indirect(reflect.ValueOf(md))
var v2 interface{} var v2 interface{}
if ind.Kind() != reflect.Struct { if ind.Kind() != reflect.Struct {
v2 = ind.Interface() v2 = ind.Interface()
@ -81,11 +91,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
panic(ErrMissPK) panic(ErrMissPK)
} }
} }
values = append(values, v1, v2) values = append(values, v1, v2)
} }
names = append(names, other_names...)
values = append(values, other_values...)
return dbase.InsertValue(orm.db, mi, true, names, values) return dbase.InsertValue(orm.db, mi, true, names, values)
} }

View File

@ -25,11 +25,12 @@ type colValue struct {
type operator int type operator int
// define Col operations
const ( const (
Col_Add operator = iota ColAdd operator = iota
Col_Minus ColMinus
Col_Multiply ColMultiply
Col_Except ColExcept
) )
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: // ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
@ -38,7 +39,7 @@ const (
// } // }
func ColValue(opt operator, value interface{}) interface{} { func ColValue(opt operator, value interface{}) interface{} {
switch opt { switch opt {
case Col_Add, Col_Minus, Col_Multiply, Col_Except: case ColAdd, ColMinus, ColMultiply, ColExcept:
default: default:
panic(fmt.Errorf("orm.ColValue wrong operator")) panic(fmt.Errorf("orm.ColValue wrong operator"))
} }
@ -60,7 +61,9 @@ type querySet struct {
relDepth int relDepth int
limit int64 limit int64
offset int64 offset int64
groups []string
orders []string orders []string
distinct bool
orm *orm orm *orm
} }
@ -105,6 +108,12 @@ func (o querySet) Offset(offset interface{}) QuerySeter {
return &o return &o
} }
// add GROUP expression
func (o querySet) GroupBy(exprs ...string) QuerySeter {
o.groups = exprs
return &o
}
// add ORDER expression. // add ORDER expression.
// "column" means ASC, "-column" means DESC. // "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter { func (o querySet) OrderBy(exprs ...string) QuerySeter {
@ -112,24 +121,30 @@ func (o querySet) OrderBy(exprs ...string) QuerySeter {
return &o return &o
} }
// add DISTINCT to SELECT
func (o querySet) Distinct() QuerySeter {
o.distinct = true
return &o
}
// set relation model to query together. // set relation model to query together.
// it will query relation models and assign to parent model. // it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter { func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
if len(params) == 0 { if len(params) == 0 {
o.relDepth = DefaultRelsDepth o.relDepth = DefaultRelsDepth
} else { } else {
for _, p := range params { for _, p := range params {
switch val := p.(type) { switch val := p.(type) {
case string: case string:
o.related = append(o.related, val) o.related = append(o.related, val)
case int: case int:
o.relDepth = val o.relDepth = val
default: default:
panic(fmt.Errorf("<QuerySeter.RelatedSel> wrong param kind: %v", val)) panic(fmt.Errorf("<QuerySeter.RelatedSel> wrong param kind: %v", val))
} }
} }
} }
return &o return &o
} }
// set condition to QuerySeter. // set condition to QuerySeter.

View File

@ -165,14 +165,14 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
if str != "" { if str != "" {
if len(str) >= 19 { if len(str) >= 19 {
str = str[:19] str = str[:19]
t, err := time.ParseInLocation(format_DateTime, str, o.orm.alias.TZ) t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ)
if err == nil { if err == nil {
t = t.In(DefaultTimeLoc) t = t.In(DefaultTimeLoc)
ind.Set(reflect.ValueOf(t)) ind.Set(reflect.ValueOf(t))
} }
} else if len(str) >= 10 { } else if len(str) >= 10 {
str = str[:10] str = str[:10]
t, err := time.ParseInLocation(format_Date, str, DefaultTimeLoc) t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc)
if err == nil { if err == nil {
ind.Set(reflect.ValueOf(t)) ind.Set(reflect.ValueOf(t))
} }
@ -255,12 +255,13 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr
// query data and map to container // query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error { func (o *rawSet) QueryRow(containers ...interface{}) error {
refs := make([]interface{}, 0, len(containers)) var (
sInds := make([]reflect.Value, 0) refs = make([]interface{}, 0, len(containers))
eTyps := make([]reflect.Type, 0) sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
)
structMode := false structMode := false
var sMi *modelInfo
for _, container := range containers { for _, container := range containers {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
@ -385,12 +386,13 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
// query data rows and map to container // query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs := make([]interface{}, 0, len(containers)) var (
sInds := make([]reflect.Value, 0) refs = make([]interface{}, 0, len(containers))
eTyps := make([]reflect.Type, 0) sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
)
structMode := false structMode := false
var sMi *modelInfo
for _, container := range containers { for _, container := range containers {
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
sInd := reflect.Indirect(val) sInd := reflect.Indirect(val)
@ -557,10 +559,9 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
args := getFlatParams(nil, o.args, o.orm.alias.TZ) args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil { rs, err := o.orm.db.Query(query, args...)
if err != nil {
return 0, err return 0, err
} else {
rs = r
} }
defer rs.Close() defer rs.Close()
@ -574,30 +575,30 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
for rs.Next() { for rs.Next() {
if cnt == 0 { if cnt == 0 {
if columns, err := rs.Columns(); err != nil { columns, err := rs.Columns()
if err != nil {
return 0, err return 0, err
}
if len(needCols) > 0 {
indexs = make([]int, 0, len(needCols))
} else { } else {
indexs = make([]int, 0, len(columns))
}
cols = columns
refs = make([]interface{}, len(cols))
for i := range refs {
var ref sql.NullString
refs[i] = &ref
if len(needCols) > 0 { if len(needCols) > 0 {
indexs = make([]int, 0, len(needCols)) for _, c := range needCols {
} else { if c == cols[i] {
indexs = make([]int, 0, len(columns)) indexs = append(indexs, i)
}
cols = columns
refs = make([]interface{}, len(cols))
for i := range refs {
var ref sql.NullString
refs[i] = &ref
if len(needCols) > 0 {
for _, c := range needCols {
if c == cols[i] {
indexs = append(indexs, i)
}
} }
} else {
indexs = append(indexs, i)
} }
} else {
indexs = append(indexs, i)
} }
} }
} }
@ -684,11 +685,9 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
args := getFlatParams(nil, o.args, o.orm.alias.TZ) args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows rs, err := o.orm.db.Query(query, args...)
if r, err := o.orm.db.Query(query, args...); err != nil { if err != nil {
return 0, err return 0, err
} else {
rs = r
} }
defer rs.Close() defer rs.Close()
@ -706,32 +705,29 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
for rs.Next() { for rs.Next() {
if cnt == 0 { if cnt == 0 {
if columns, err := rs.Columns(); err != nil { columns, err := rs.Columns()
if err != nil {
return 0, err return 0, err
} else { }
cols = columns cols = columns
refs = make([]interface{}, len(cols)) refs = make([]interface{}, len(cols))
for i := range refs { for i := range refs {
if keyCol == cols[i] { if keyCol == cols[i] {
keyIndex = i keyIndex = i
}
if typ == 1 || keyIndex == i {
var ref sql.NullString
refs[i] = &ref
} else {
var ref interface{}
refs[i] = &ref
}
if valueCol == cols[i] {
valueIndex = i
}
} }
if typ == 1 || keyIndex == i {
if keyIndex == -1 || valueIndex == -1 { var ref sql.NullString
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) refs[i] = &ref
} else {
var ref interface{}
refs[i] = &ref
} }
if valueCol == cols[i] {
valueIndex = i
}
}
if keyIndex == -1 || valueIndex == -1 {
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
} }
} }

View File

@ -31,13 +31,26 @@ import (
var _ = os.PathSeparator var _ = os.PathSeparator
var ( var (
test_Date = format_Date + " -0700" testDate = formatDate + " -0700"
test_DateTime = format_DateTime + " -0700" testDateTime = formatDateTime + " -0700"
) )
func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok bool) { type argAny []interface{}
// get interface by index from interface slice
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) {
if len(args) == 0 { if len(args) == 0 {
return fmt.Errorf("miss args"), false return false, fmt.Errorf("miss args")
} }
b := args[0] b := args[0]
arg := argAny(args) arg := argAny(args)
@ -71,21 +84,21 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok b
wrongArg: wrongArg:
if err != nil { if err != nil {
return err, false return false, err
} }
return nil, true return true, nil
} }
func AssertIs(a interface{}, args ...interface{}) error { func AssertIs(a interface{}, args ...interface{}) error {
if err, ok := ValuesCompare(true, a, args...); ok == false { if ok, err := ValuesCompare(true, a, args...); ok == false {
return err return err
} }
return nil return nil
} }
func AssertNot(a interface{}, args ...interface{}) error { func AssertNot(a interface{}, args ...interface{}) error {
if err, ok := ValuesCompare(false, a, args...); ok == false { if ok, err := ValuesCompare(false, a, args...); ok == false {
return err return err
} }
return nil return nil
@ -171,8 +184,11 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Comment)) RegisterModel(new(Comment))
RegisterModel(new(UserBig)) RegisterModel(new(UserBig))
RegisterModel(new(PostTags)) RegisterModel(new(PostTags))
RegisterModel(new(Group))
RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions))
err := RunSyncdb("default", true, false) err := RunSyncdb("default", true, Debug)
throwFail(t, err) throwFail(t, err)
modelCache.clean() modelCache.clean()
@ -187,6 +203,9 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Comment)) RegisterModel(new(Comment))
RegisterModel(new(UserBig)) RegisterModel(new(UserBig))
RegisterModel(new(PostTags)) RegisterModel(new(PostTags))
RegisterModel(new(Group))
RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions))
BootStrap() BootStrap()
@ -208,7 +227,7 @@ func TestModelSyntax(t *testing.T) {
} }
} }
var Data_Values = map[string]interface{}{ var DataValues = map[string]interface{}{
"Boolean": true, "Boolean": true,
"Char": "char", "Char": "char",
"Text": "text", "Text": "text",
@ -235,7 +254,7 @@ func TestDataTypes(t *testing.T) {
d := Data{} d := Data{}
ind := reflect.Indirect(reflect.ValueOf(&d)) ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value)) e.Set(reflect.ValueOf(value))
} }
@ -244,22 +263,22 @@ func TestDataTypes(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
d = Data{Id: 1} d = Data{ID: 1}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d)) ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
vu := e.Interface() vu := e.Interface()
switch name { switch name {
case "Date": case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime": case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
} }
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
@ -278,7 +297,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
d = DataNull{Id: 1} d = DataNull{ID: 1}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
@ -309,7 +328,7 @@ func TestNullDataTypes(t *testing.T) {
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err) throwFail(t, err)
d = DataNull{Id: 2} d = DataNull{ID: 2}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
@ -362,7 +381,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 3)) throwFail(t, AssertIs(id, 3))
d = DataNull{Id: 3} d = DataNull{ID: 3}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
@ -402,7 +421,7 @@ func TestDataCustomTypes(t *testing.T) {
d := DataCustom{} d := DataCustom{}
ind := reflect.Indirect(reflect.ValueOf(&d)) ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
if !e.IsValid() { if !e.IsValid() {
continue continue
@ -414,13 +433,13 @@ func TestDataCustomTypes(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
d = DataCustom{Id: 1} d = DataCustom{ID: 1}
err = dORM.Read(&d) err = dORM.Read(&d)
throwFail(t, err) throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d)) ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
if !e.IsValid() { if !e.IsValid() {
continue continue
@ -451,7 +470,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
u := &User{Id: user.Id} u := &User{ID: user.ID}
err = dORM.Read(u) err = dORM.Read(u)
throwFail(t, err) throwFail(t, err)
@ -461,8 +480,8 @@ func TestCRUD(t *testing.T) {
throwFail(t, AssertIs(u.Status, 3)) throwFail(t, AssertIs(u.Status, 3))
throwFail(t, AssertIs(u.IsStaff, true)) throwFail(t, AssertIs(u.IsStaff, true))
throwFail(t, AssertIs(u.IsActive, true)) throwFail(t, AssertIs(u.IsActive, true))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), test_Date)) throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), test_DateTime)) throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime))
user.UserName = "astaxie" user.UserName = "astaxie"
user.Profile = profile user.Profile = profile
@ -470,11 +489,11 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id} u = &User{ID: user.ID}
err = dORM.Read(u) err = dORM.Read(u)
throwFailNow(t, err) throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "astaxie")) throwFail(t, AssertIs(u.UserName, "astaxie"))
throwFail(t, AssertIs(u.Profile.Id, profile.Id)) throwFail(t, AssertIs(u.Profile.ID, profile.ID))
u = &User{UserName: "astaxie", Password: "pass"} u = &User{UserName: "astaxie", Password: "pass"}
err = dORM.Read(u, "UserName") err = dORM.Read(u, "UserName")
@ -487,7 +506,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id} u = &User{ID: user.ID}
err = dORM.Read(u) err = dORM.Read(u)
throwFailNow(t, err) throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "QQ")) throwFail(t, AssertIs(u.UserName, "QQ"))
@ -497,7 +516,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id} u = &User{ID: user.ID}
err = dORM.Read(u) err = dORM.Read(u)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(true, u.Profile == nil)) throwFail(t, AssertIs(true, u.Profile == nil))
@ -506,7 +525,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
u = &User{Id: 100} u = &User{ID: 100}
err = dORM.Read(u) err = dORM.Read(u)
throwFail(t, AssertIs(err, ErrNoRows)) throwFail(t, AssertIs(err, ErrNoRows))
@ -516,7 +535,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
ub = UserBig{Id: 1} ub = UserBig{ID: 1}
err = dORM.Read(&ub) err = dORM.Read(&ub)
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(ub.Name, "name")) throwFail(t, AssertIs(ub.Name, "name"))
@ -586,7 +605,7 @@ func TestInsertTestData(t *testing.T) {
throwFail(t, AssertIs(id, 4)) throwFail(t, AssertIs(id, 4))
tags := []*Tag{ tags := []*Tag{
{Name: "golang", BestPost: &Post{Id: 2}}, {Name: "golang", BestPost: &Post{ID: 2}},
{Name: "example"}, {Name: "example"},
{Name: "format"}, {Name: "format"},
{Name: "c++"}, {Name: "c++"},
@ -635,10 +654,47 @@ The program—and web server—godoc processes Go source files to extract docume
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(id > 0, true)) throwFail(t, AssertIs(id > 0, true))
} }
permissions := []*Permission{
{Name: "writePosts"},
{Name: "readComments"},
{Name: "readPosts"},
}
groups := []*Group{
{
Name: "admins",
Permissions: []*Permission{permissions[0], permissions[1], permissions[2]},
},
{
Name: "users",
Permissions: []*Permission{permissions[1], permissions[2]},
},
}
for _, permission := range permissions {
id, err := dORM.Insert(permission)
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
}
for _, group := range groups {
_, err := dORM.Insert(group)
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
num := len(group.Permissions)
if num > 0 {
nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions)
throwFailNow(t, err)
throwFailNow(t, AssertIs(nums, num))
}
}
} }
func TestCustomField(t *testing.T) { func TestCustomField(t *testing.T) {
user := User{Id: 2} user := User{ID: 2}
err := dORM.Read(&user) err := dORM.Read(&user)
throwFailNow(t, err) throwFailNow(t, err)
@ -648,7 +704,7 @@ func TestCustomField(t *testing.T) {
_, err = dORM.Update(&user, "Langs", "Extra") _, err = dORM.Update(&user, "Langs", "Extra")
throwFailNow(t, err) throwFailNow(t, err)
user = User{Id: 2} user = User{ID: 2}
err = dORM.Read(&user) err = dORM.Read(&user)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(len(user.Langs), 2)) throwFailNow(t, AssertIs(len(user.Langs), 2))
@ -702,7 +758,7 @@ func TestOperators(t *testing.T) {
var shouldNum int var shouldNum int
if IsSqlite { if IsSqlite || IsTidb {
shouldNum = 2 shouldNum = 2
} else { } else {
shouldNum = 0 shouldNum = 0
@ -740,7 +796,7 @@ func TestOperators(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
if IsSqlite { if IsSqlite || IsTidb {
shouldNum = 1 shouldNum = 1
} else { } else {
shouldNum = 0 shouldNum = 0
@ -758,7 +814,7 @@ func TestOperators(t *testing.T) {
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 2)) throwFail(t, AssertIs(num, 2))
if IsSqlite { if IsSqlite || IsTidb {
shouldNum = 2 shouldNum = 2
} else { } else {
shouldNum = 0 shouldNum = 0
@ -889,9 +945,9 @@ func TestAll(t *testing.T) {
throwFailNow(t, AssertIs(users2[0].UserName, "slene")) throwFailNow(t, AssertIs(users2[0].UserName, "slene"))
throwFailNow(t, AssertIs(users2[1].UserName, "astaxie")) throwFailNow(t, AssertIs(users2[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(users2[2].UserName, "nobody")) throwFailNow(t, AssertIs(users2[2].UserName, "nobody"))
throwFailNow(t, AssertIs(users2[0].Id, 0)) throwFailNow(t, AssertIs(users2[0].ID, 0))
throwFailNow(t, AssertIs(users2[1].Id, 0)) throwFailNow(t, AssertIs(users2[1].ID, 0))
throwFailNow(t, AssertIs(users2[2].Id, 0)) throwFailNow(t, AssertIs(users2[2].ID, 0))
throwFailNow(t, AssertIs(users2[0].Profile == nil, false)) throwFailNow(t, AssertIs(users2[0].Profile == nil, false))
throwFailNow(t, AssertIs(users2[1].Profile == nil, false)) throwFailNow(t, AssertIs(users2[1].Profile == nil, false))
throwFailNow(t, AssertIs(users2[2].Profile == nil, true)) throwFailNow(t, AssertIs(users2[2].Profile == nil, true))
@ -986,6 +1042,10 @@ func TestValuesFlat(t *testing.T) {
} }
func TestRelatedSel(t *testing.T) { func TestRelatedSel(t *testing.T) {
if IsTidb {
// Skip it. TiDB does not support relation now.
return
}
qs := dORM.QueryTable("user") qs := dORM.QueryTable("user")
num, err := qs.Filter("profile__age", 28).Count() num, err := qs.Filter("profile__age", 28).Count()
throwFail(t, err) throwFail(t, err)
@ -1112,7 +1172,7 @@ func TestReverseQuery(t *testing.T) {
func TestLoadRelated(t *testing.T) { func TestLoadRelated(t *testing.T) {
// load reverse foreign key // load reverse foreign key
user := User{Id: 3} user := User{ID: 3}
err := dORM.Read(&user) err := dORM.Read(&user)
throwFailNow(t, err) throwFailNow(t, err)
@ -1121,7 +1181,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].User.Id, 3)) throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3))
num, err = dORM.LoadRelated(&user, "Posts", true) num, err = dORM.LoadRelated(&user, "Posts", true)
throwFailNow(t, err) throwFailNow(t, err)
@ -1143,8 +1203,8 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
// load reverse one to one // load reverse one to one
profile := Profile{Id: 3} profile := Profile{ID: 3}
profile.BestPost = &Post{Id: 2} profile.BestPost = &Post{ID: 2}
num, err = dORM.Update(&profile, "BestPost") num, err = dORM.Update(&profile, "BestPost")
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(num, 1))
@ -1183,7 +1243,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false)) throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false))
throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples")) throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples"))
post := Post{Id: 2} post := Post{ID: 2}
// load rel foreign key // load rel foreign key
err = dORM.Read(&post) err = dORM.Read(&post)
@ -1204,7 +1264,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.User.Profile.Age, 30)) throwFailNow(t, AssertIs(post.User.Profile.Age, 30))
// load rel m2m // load rel m2m
post = Post{Id: 2} post = Post{ID: 2}
err = dORM.Read(&post) err = dORM.Read(&post)
throwFailNow(t, err) throwFailNow(t, err)
@ -1224,7 +1284,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie")) throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie"))
// load reverse m2m // load reverse m2m
tag := Tag{Id: 1} tag := Tag{ID: 1}
err = dORM.Read(&tag) err = dORM.Read(&tag)
throwFailNow(t, err) throwFailNow(t, err)
@ -1233,19 +1293,19 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true))
num, err = dORM.LoadRelated(&tag, "Posts", true) num, err = dORM.LoadRelated(&tag, "Posts", true)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene")) throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene"))
} }
func TestQueryM2M(t *testing.T) { func TestQueryM2M(t *testing.T) {
post := Post{Id: 4} post := Post{ID: 4}
m2m := dORM.QueryM2M(&post, "Tags") m2m := dORM.QueryM2M(&post, "Tags")
tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}} tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}}
@ -1319,7 +1379,7 @@ func TestQueryM2M(t *testing.T) {
for _, post := range posts { for _, post := range posts {
p := post.(*Post) p := post.(*Post)
p.User = &User{Id: 1} p.User = &User{ID: 1}
_, err := dORM.Insert(post) _, err := dORM.Insert(post)
throwFailNow(t, err) throwFailNow(t, err)
} }
@ -1394,6 +1454,18 @@ func TestQueryRelate(t *testing.T) {
// throwFailNow(t, AssertIs(num, 2)) // throwFailNow(t, AssertIs(num, 2))
} }
func TestPkManyRelated(t *testing.T) {
permission := &Permission{Name: "readPosts"}
err := dORM.Read(permission, "Name")
throwFailNow(t, err)
var groups []*Group
qs := dORM.QueryTable("Group")
num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2))
}
func TestPrepareInsert(t *testing.T) { func TestPrepareInsert(t *testing.T) {
qs := dORM.QueryTable("user") qs := dORM.QueryTable("user")
i, err := qs.PrepareInsert() i, err := qs.PrepareInsert()
@ -1459,10 +1531,10 @@ func TestRawQueryRow(t *testing.T) {
Decimal float64 Decimal float64
) )
data_values := make(map[string]interface{}, len(Data_Values)) dataValues := make(map[string]interface{}, len(DataValues))
for k, v := range Data_Values { for k, v := range DataValues {
data_values[strings.ToLower(k)] = v dataValues[strings.ToLower(k)] = v
} }
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
@ -1488,14 +1560,14 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(id, 1)) throwFail(t, AssertIs(id, 1))
case "date": case "date":
v = v.(time.Time).In(DefaultTimeLoc) v = v.(time.Time).In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_Date)) throwFail(t, AssertIs(v, value, testDate))
case "datetime": case "datetime":
v = v.(time.Time).In(DefaultTimeLoc) v = v.(time.Time).In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_DateTime)) throwFail(t, AssertIs(v, value, testDateTime))
default: default:
throwFail(t, AssertIs(v, data_values[col])) throwFail(t, AssertIs(v, dataValues[col]))
} }
} }
@ -1529,16 +1601,16 @@ func TestQueryRows(t *testing.T) {
ind := reflect.Indirect(reflect.ValueOf(datas[0])) ind := reflect.Indirect(reflect.ValueOf(datas[0]))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
vu := e.Interface() vu := e.Interface()
switch name { switch name {
case "Date": case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime": case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
} }
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
@ -1553,16 +1625,16 @@ func TestQueryRows(t *testing.T) {
ind = reflect.Indirect(reflect.ValueOf(datas2[0])) ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
for name, value := range Data_Values { for name, value := range DataValues {
e := ind.FieldByName(name) e := ind.FieldByName(name)
vu := e.Interface() vu := e.Interface()
switch name { switch name {
case "Date": case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime": case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
} }
throwFail(t, AssertIs(vu == value, true), value, vu) throwFail(t, AssertIs(vu == value, true), value, vu)
} }
@ -1699,25 +1771,25 @@ func TestUpdate(t *testing.T) {
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{ num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Add, 100), "Nums": ColValue(ColAdd, 100),
}) })
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{ num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Minus, 50), "Nums": ColValue(ColMinus, 50),
}) })
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{ num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Multiply, 3), "Nums": ColValue(ColMultiply, 3),
}) })
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{ num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Except, 5), "Nums": ColValue(ColExcept, 5),
}) })
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
@ -1838,15 +1910,15 @@ func TestReadOrCreate(t *testing.T) {
throwFail(t, AssertIs(u.Status, 7)) throwFail(t, AssertIs(u.Status, 7))
throwFail(t, AssertIs(u.IsStaff, false)) throwFail(t, AssertIs(u.IsStaff, false))
throwFail(t, AssertIs(u.IsActive, true)) throwFail(t, AssertIs(u.IsActive, true))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date)) throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime)) throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime))
nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"} nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"}
created, pk, err = dORM.ReadOrCreate(nu, "UserName") created, pk, err = dORM.ReadOrCreate(nu, "UserName")
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(created, false)) throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(nu.Id, u.Id)) throwFail(t, AssertIs(nu.ID, u.ID))
throwFail(t, AssertIs(pk, u.Id)) throwFail(t, AssertIs(pk, u.ID))
throwFail(t, AssertIs(nu.UserName, u.UserName)) throwFail(t, AssertIs(nu.UserName, u.UserName))
throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above
throwFail(t, AssertIs(nu.Password, u.Password)) throwFail(t, AssertIs(nu.Password, u.Password))

View File

@ -16,6 +16,7 @@ package orm
import "errors" import "errors"
// QueryBuilder is the Query builder interface
type QueryBuilder interface { type QueryBuilder interface {
Select(fields ...string) QueryBuilder Select(fields ...string) QueryBuilder
From(tables ...string) QueryBuilder From(tables ...string) QueryBuilder
@ -43,15 +44,18 @@ type QueryBuilder interface {
String() string String() string
} }
// NewQueryBuilder return the QueryBuilder
func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { func NewQueryBuilder(driver string) (qb QueryBuilder, err error) {
if driver == "mysql" { if driver == "mysql" {
qb = new(MySQLQueryBuilder) qb = new(MySQLQueryBuilder)
} else if driver == "tidb" {
qb = new(TiDBQueryBuilder)
} else if driver == "postgres" { } else if driver == "postgres" {
err = errors.New("postgres query builder is not supported yet!") err = errors.New("postgres query builder is not supported yet")
} else if driver == "sqlite" { } else if driver == "sqlite" {
err = errors.New("sqlite query builder is not supported yet!") err = errors.New("sqlite query builder is not supported yet")
} else { } else {
err = errors.New("unknown driver for query builder!") err = errors.New("unknown driver for query builder")
} }
return return
} }

View File

@ -20,134 +20,160 @@ import (
"strings" "strings"
) )
const COMMA_SPACE = ", " // CommaSpace is the seperation
const CommaSpace = ", "
// MySQLQueryBuilder is the SQL build
type MySQLQueryBuilder struct { type MySQLQueryBuilder struct {
Tokens []string Tokens []string
} }
// Select will join the fields
func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb return qb
} }
// From join the tables
func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb return qb
} }
// InnerJoin INNER JOIN the table
func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table) qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb return qb
} }
// LeftJoin LEFT JOIN the table
func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb return qb
} }
// RightJoin RIGHT JOIN the table
func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb return qb
} }
// On join with on cond
func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond) qb.Tokens = append(qb.Tokens, "ON", cond)
return qb return qb
} }
// Where join the Where cond
func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond) qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb return qb
} }
// And join the and cond
func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond) qb.Tokens = append(qb.Tokens, "AND", cond)
return qb return qb
} }
// Or join the or cond
func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond) qb.Tokens = append(qb.Tokens, "OR", cond)
return qb return qb
} }
// In join the IN (vals)
func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, COMMA_SPACE), ")") qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb return qb
} }
// OrderBy join the Order by fields
func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb return qb
} }
// Asc join the asc
func (qb *MySQLQueryBuilder) Asc() QueryBuilder { func (qb *MySQLQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC") qb.Tokens = append(qb.Tokens, "ASC")
return qb return qb
} }
// Desc join the desc
func (qb *MySQLQueryBuilder) Desc() QueryBuilder { func (qb *MySQLQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC") qb.Tokens = append(qb.Tokens, "DESC")
return qb return qb
} }
// Limit join the limit num
func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb return qb
} }
// Offset join the offset num
func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb return qb
} }
// GroupBy join the Group by fields
func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb return qb
} }
// Having join the Having cond
func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond) qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb return qb
} }
// Update join the update table
func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb return qb
} }
// Set join the set kv
func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb return qb
} }
// Delete join the Delete tables
func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE") qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 { if len(tables) != 0 {
qb.Tokens = append(qb.Tokens, strings.Join(tables, COMMA_SPACE)) qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
} }
return qb return qb
} }
// InsertInto join the insert SQL
func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table) qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 { if len(fields) != 0 {
fieldsStr := strings.Join(fields, COMMA_SPACE) fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
} }
return qb return qb
} }
// Values join the Values(vals)
func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder {
valsStr := strings.Join(vals, COMMA_SPACE) valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb return qb
} }
// Subquery join the sub as alias
func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias) return fmt.Sprintf("(%s) AS %s", sub, alias)
} }
// String join all Tokens
func (qb *MySQLQueryBuilder) String() string { func (qb *MySQLQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ") return strings.Join(qb.Tokens, " ")
} }

176
orm/qb_tidb.go Normal file
View File

@ -0,0 +1,176 @@
// Copyright 2015 TiDB Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strconv"
"strings"
)
// TiDBQueryBuilder is the SQL build
type TiDBQueryBuilder struct {
Tokens []string
}
// Select will join the fields
func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb
}
// From join the tables
func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb
}
// InnerJoin INNER JOIN the table
func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb
}
// LeftJoin LEFT JOIN the table
func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb
}
// RightJoin RIGHT JOIN the table
func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb
}
// On join with on cond
func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond)
return qb
}
// Where join the Where cond
func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb
}
// And join the and cond
func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond)
return qb
}
// Or join the or cond
func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond)
return qb
}
// In join the IN (vals)
func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb
}
// OrderBy join the Order by fields
func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb
}
// Asc join the asc
func (qb *TiDBQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC")
return qb
}
// Desc join the desc
func (qb *TiDBQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC")
return qb
}
// Limit join the limit num
func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb
}
// Offset join the offset num
func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb
}
// GroupBy join the Group by fields
func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb
}
// Having join the Having cond
func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb
}
// Update join the update table
func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb
}
// Set join the set kv
func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb
}
// Delete join the Delete tables
func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 {
qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
}
return qb
}
// InsertInto join the insert SQL
func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 {
fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
}
return qb
}
// Values join the Values(vals)
func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder {
valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb
}
// Subquery join the sub as alias
func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias)
}
// String join all Tokens
func (qb *TiDBQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ")
}

View File

@ -20,13 +20,13 @@ import (
"time" "time"
) )
// database driver // Driver define database driver
type Driver interface { type Driver interface {
Name() string Name() string
Type() DriverType Type() DriverType
} }
// field info // Fielder define field info
type Fielder interface { type Fielder interface {
String() string String() string
FieldType() int FieldType() int
@ -34,84 +34,315 @@ type Fielder interface {
RawValue() interface{} RawValue() interface{}
} }
// orm struct // Ormer define the orm interface
type Ormer interface { type Ormer interface {
Read(interface{}, ...string) error // read data to model
ReadOrCreate(interface{}, string, ...string) (bool, int64, error) // for example:
// this will find User by Id field
// u = &User{Id: user.Id}
// err = Ormer.Read(u)
// this will find User by UserName field
// u = &User{UserName: "astaxie", Password: "pass"}
// err = Ormer.Read(u, "UserName")
Read(md interface{}, cols ...string) error
// Try to read a row from the database, or insert one if it doesn't exist
ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error)
// insert model data to database
// for example:
// user := new(User)
// id, err = Ormer.Insert(user)
// user must a pointer and Insert will set user's pk field
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
InsertMulti(int, interface{}) (int64, error) // insert some models to database
Update(interface{}, ...string) (int64, error) InsertMulti(bulk int, mds interface{}) (int64, error)
Delete(interface{}) (int64, error) // update model to database.
LoadRelated(interface{}, string, ...interface{}) (int64, error) // cols set the columns those want to update.
QueryM2M(interface{}, string) QueryM2Mer // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns
QueryTable(interface{}) QuerySeter // for example:
Using(string) error // user := User{Id: 2}
// user.Langs = append(user.Langs, "zh-CN", "en-US")
// user.Extra.Name = "beego"
// user.Extra.Data = "orm"
// num, err = Ormer.Update(&user, "Langs", "Extra")
Update(md interface{}, cols ...string) (int64, error)
// delete model in database
Delete(md interface{}) (int64, error)
// load related models to md model.
// args are limit, offset int and order string.
//
// example:
// Ormer.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//args[0] bool true useDefaultRelsDepth ; false depth 0
//args[0] int loadRelationDepth
//args[1] int limit default limit 1000
//args[2] int offset default offset 0
//args[3] string order for example : "-Id"
// make sure the relation is defined in model struct tags.
LoadRelated(md interface{}, name string, args ...interface{}) (int64, error)
// create a models to models queryer
// for example:
// post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags")
QueryM2M(md interface{}, name string) QueryM2Mer
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
QueryTable(ptrStructOrTableName interface{}) QuerySeter
// switch to another registered database driver by given name.
Using(name string) error
// begin transaction
// for example:
// o := NewOrm()
// err := o.Begin()
// ...
// err = o.Rollback()
Begin() error Begin() error
// commit transaction
Commit() error Commit() error
// rollback transaction
Rollback() error Rollback() error
Raw(string, ...interface{}) RawSeter // return a raw query seter for raw sql string.
// for example:
// ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec()
// // update user testing's name to slene
Raw(query string, args ...interface{}) RawSeter
Driver() Driver Driver() Driver
} }
// insert prepared statement // Inserter insert prepared statement
type Inserter interface { type Inserter interface {
Insert(interface{}) (int64, error) Insert(interface{}) (int64, error)
Close() error Close() error
} }
// query seter // QuerySeter query seter
type QuerySeter interface { type QuerySeter interface {
// add condition expression to QuerySeter.
// for example:
// filter by UserName == 'slene'
// qs.Filter("UserName", "slene")
// sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28
// Filter("profile__Age", 28)
// // time compare
// qs.Filter("created", time.Now())
Filter(string, ...interface{}) QuerySeter Filter(string, ...interface{}) QuerySeter
// add NOT condition to querySeter.
// have the same usage as Filter
Exclude(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter
// set condition to QuerySeter.
// sql's where condition
// cond := orm.NewCondition()
// cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond1).Count()
SetCond(*Condition) QuerySeter SetCond(*Condition) QuerySeter
Limit(interface{}, ...interface{}) QuerySeter // add LIMIT value.
Offset(interface{}) QuerySeter // args[0] means offset, e.g. LIMIT num,offset.
OrderBy(...string) QuerySeter // if Limit <= 0 then Limit will be set to default limit ,eg 1000
RelatedSel(...interface{}) QuerySeter // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000
// for example:
// qs.Limit(10, 2)
// // sql-> limit 10 offset 2
Limit(limit interface{}, args ...interface{}) QuerySeter
// add OFFSET value
// same as Limit function's args[0]
Offset(offset interface{}) QuerySeter
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
// for example:
// qs.OrderBy("-status")
OrderBy(exprs ...string) QuerySeter
// set relation model to query together.
// it will query relation models and assign to parent model.
// for example:
// // will load all related fields use left join .
// qs.RelatedSel().One(&user)
// // will load related field only profile
// qs.RelatedSel("profile").One(&user)
// user.Profile.Age = 32
RelatedSel(params ...interface{}) QuerySeter
// return QuerySeter execution result number
// for example:
// num, err = qs.Filter("profile__age__gt", 28).Count()
Count() (int64, error) Count() (int64, error)
// check result empty or not after QuerySeter executed
// the same as QuerySeter.Count > 0
Exist() bool Exist() bool
Update(Params) (int64, error) // execute update with parameters
// for example:
// num, err = qs.Filter("user_name", "slene").Update(Params{
// "Nums": ColValue(Col_Minus, 50),
// }) // user slene's Nums will minus 50
// num, err = qs.Filter("UserName", "slene").Update(Params{
// "user_name": "slene2"
// }) // user slene's name will change to slene2
Update(values Params) (int64, error)
// delete from table
//for example:
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
// //delete two user who's name is testing1 or testing2
Delete() (int64, error) Delete() (int64, error)
// return a insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
// num, err = i.Insert(&user1) // user table will add one record user1 at once
// num, err = i.Insert(&user2) // user table will add one record user2 at once
// err = i.Close() //don't forget call Close
PrepareInsert() (Inserter, error) PrepareInsert() (Inserter, error)
All(interface{}, ...string) (int64, error) // query all data and map to containers.
One(interface{}, ...string) error // cols means the columns when querying.
Values(*[]Params, ...string) (int64, error) // for example:
ValuesList(*[]ParamsList, ...string) (int64, error) // var users []*User
ValuesFlat(*ParamsList, string) (int64, error) // qs.All(&users) // users[0],users[1],users[2] ...
RowsToMap(*Params, string, string) (int64, error) All(container interface{}, cols ...string) (int64, error)
RowsToStruct(interface{}, string, string) (int64, error) // query one row data and map to containers.
// cols means the columns when querying.
// for example:
// var user User
// qs.One(&user) //user.UserName == "slene"
One(container interface{}, cols ...string) error
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
// for example:
// var maps []Params
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
Values(results *[]Params, exprs ...string) (int64, error)
// query all data and map to [][]interface
// it converts data to [][column_index]value
// for example:
// var list []ParamsList
// qs.ValuesList(&list) // list[0][1] == "slene"
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
// query all data and map to []interface.
// it's designed for one column record set, auto change to []value, not [][column]value.
// for example:
// var list ParamsList
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
ValuesFlat(result *ParamsList, expr string) (int64, error)
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
} }
// model to model query struct // QueryM2Mer model to model query struct
// all operations are on the m2m table only, will not affect the origin model table
type QueryM2Mer interface { type QueryM2Mer interface {
// add models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}{ ... }
// param could also be any of the follow
// []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}}
// &Tag{Id:5,Name: "TestTag3"}
// []interface{}{&Tag{Id:6,Name: "TestTag4"}}
// insert one or more rows to m2m table
// make sure the relation is defined in post model struct tag.
Add(...interface{}) (int64, error) Add(...interface{}) (int64, error)
// remove models following the origin model relationship
// only delete rows from m2m table
// for example:
//tag3 := &Tag{Id:5,Name: "TestTag3"}
//num, err = m2m.Remove(tag3)
Remove(...interface{}) (int64, error) Remove(...interface{}) (int64, error)
// check model is existed in relationship of origin model
Exist(interface{}) bool Exist(interface{}) bool
// clean all models in related of origin model
Clear() (int64, error) Clear() (int64, error)
// count all related models of origin model
Count() (int64, error) Count() (int64, error)
} }
// raw query statement // RawPreparer raw query statement
type RawPreparer interface { type RawPreparer interface {
Exec(...interface{}) (sql.Result, error) Exec(...interface{}) (sql.Result, error)
Close() error Close() error
} }
// raw query seter // RawSeter raw query seter
// create From Ormer.Raw
// for example:
// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q)
// rs := Ormer.Raw(sql, 1)
type RawSeter interface { type RawSeter interface {
//execute sql and get result
Exec() (sql.Result, error) Exec() (sql.Result, error)
QueryRow(...interface{}) error //query data and map to container
QueryRows(...interface{}) (int64, error) //for example:
// var name string
// var id int
// rs.QueryRow(&id,&name) // id==2 name=="slene"
QueryRow(containers ...interface{}) error
// query data rows and map to container
// var ids []int
// var names []int
// query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q)
// num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"}
QueryRows(containers ...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter SetArgs(...interface{}) RawSeter
Values(*[]Params, ...string) (int64, error) // query data to []map[string]interface
ValuesList(*[]ParamsList, ...string) (int64, error) // see QuerySeter's Values
ValuesFlat(*ParamsList, ...string) (int64, error) Values(container *[]Params, cols ...string) (int64, error)
RowsToMap(*Params, string, string) (int64, error) // query data to [][]interface
RowsToStruct(interface{}, string, string) (int64, error) // see QuerySeter's ValuesList
ValuesList(container *[]ParamsList, cols ...string) (int64, error)
// query data to []interface
// see QuerySeter's ValuesFlat
ValuesFlat(container *ParamsList, cols ...string) (int64, error)
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
// return prepared raw statement for used in times.
// for example:
// pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare()
// r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`)
Prepare() (RawPreparer, error) Prepare() (RawPreparer, error)
} }
// statement querier // stmtQuerier statement querier
type stmtQuerier interface { type stmtQuerier interface {
Close() error Close() error
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
@ -160,8 +391,8 @@ type dbBaser interface {
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
OperatorSql(string) string OperatorSQL(string) string
GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string) GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)

View File

@ -22,9 +22,10 @@ import (
"time" "time"
) )
// StrTo is the target string
type StrTo string type StrTo string
// set string // Set string
func (f *StrTo) Set(v string) { func (f *StrTo) Set(v string) {
if v != "" { if v != "" {
*f = StrTo(v) *f = StrTo(v)
@ -33,93 +34,93 @@ func (f *StrTo) Set(v string) {
} }
} }
// clean string // Clear string
func (f *StrTo) Clear() { func (f *StrTo) Clear() {
*f = StrTo(0x1E) *f = StrTo(0x1E)
} }
// check string exist // Exist check string exist
func (f StrTo) Exist() bool { func (f StrTo) Exist() bool {
return string(f) != string(0x1E) return string(f) != string(0x1E)
} }
// string to bool // Bool string to bool
func (f StrTo) Bool() (bool, error) { func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String()) return strconv.ParseBool(f.String())
} }
// string to float32 // Float32 string to float32
func (f StrTo) Float32() (float32, error) { func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32) v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err return float32(v), err
} }
// string to float64 // Float64 string to float64
func (f StrTo) Float64() (float64, error) { func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64) return strconv.ParseFloat(f.String(), 64)
} }
// string to int // Int string to int
func (f StrTo) Int() (int, error) { func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32) v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err return int(v), err
} }
// string to int8 // Int8 string to int8
func (f StrTo) Int8() (int8, error) { func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8) v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err return int8(v), err
} }
// string to int16 // Int16 string to int16
func (f StrTo) Int16() (int16, error) { func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16) v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err return int16(v), err
} }
// string to int32 // Int32 string to int32
func (f StrTo) Int32() (int32, error) { func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32) v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err return int32(v), err
} }
// string to int64 // Int64 string to int64
func (f StrTo) Int64() (int64, error) { func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64) v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err return int64(v), err
} }
// string to uint // Uint string to uint
func (f StrTo) Uint() (uint, error) { func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32) v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err return uint(v), err
} }
// string to uint8 // Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) { func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8) v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err return uint8(v), err
} }
// string to uint16 // Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) { func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16) v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err return uint16(v), err
} }
// string to uint31 // Uint32 string to uint31
func (f StrTo) Uint32() (uint32, error) { func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32) v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err return uint32(v), err
} }
// string to uint64 // Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) { func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64) v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err return uint64(v), err
} }
// string to string // String string to string
func (f StrTo) String() string { func (f StrTo) String() string {
if f.Exist() { if f.Exist() {
return string(f) return string(f)
@ -127,7 +128,7 @@ func (f StrTo) String() string {
return "" return ""
} }
// interface to string // ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) { func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) { switch v := value.(type) {
case bool: case bool:
@ -166,7 +167,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s return s
} }
// interface to int64 // ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) { func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value) val := reflect.ValueOf(value)
switch value.(type) { switch value.(type) {
@ -248,30 +249,12 @@ func (a argInt) Get(i int, args ...int) (r int) {
return return
} }
type argAny []interface{}
// get interface by index from interface slice
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
// parse time to string with location // parse time to string with location
func timeParse(dateString, format string) (time.Time, error) { func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err return tp, err
} }
// format time string
func timeFormat(t time.Time, format string) string {
return t.Format(format)
}
// get pointer indirect type // get pointer indirect type
func indirectType(v reflect.Type) reflect.Type { func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() { switch v.Kind() {

Some files were not shown because too many files have changed in this diff Show More