1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-16 12:10:55 +00:00

Merge remote-tracking branch 'refs/remotes/astaxie/master'

This commit is contained in:
bradycao 2016-01-19 09:43:43 +08:00
commit b48f251043
153 changed files with 7610 additions and 5409 deletions

1
.gitignore vendored
View File

@ -2,3 +2,4 @@
.DS_Store
*.swp
*.swo
beego.iml

30
.travis.yml Normal file
View File

@ -0,0 +1,30 @@
language: go
go:
- 1.5.1
services:
- redis-server
- mysql
- postgresql
- memcached
env:
- ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
- ORM_DRIVER=mysql ORM_SOURCE="root:@/orm_test?charset=utf8"
- ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
install:
- go get github.com/lib/pq
- go get github.com/go-sql-driver/mysql
- go get github.com/mattn/go-sqlite3
- go get github.com/bradfitz/gomemcache/memcache
- go get github.com/garyburd/redigo/redis
- go get github.com/beego/x2j
- go get github.com/beego/goyaml2
- go get github.com/belogik/goes
- go get github.com/couchbase/go-couchbase
- go get github.com/siddontang/ledisdb/config
- go get github.com/siddontang/ledisdb/ledis
before_script:
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
- sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi"
- sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi"

52
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,52 @@
# Contributing to beego
beego is an open source project.
It is the work of hundreds of contributors. We appreciate your help!
Here are instructions to get you started. They are probably not perfect,
please let us know if anything feels wrong or incomplete.
## Contribution guidelines
### Pull requests
First of all. beego follow the gitflow. So please send you pull request
to **develop** branch. We will close the pull request to master branch.
We are always happy to receive pull requests, and do our best to
review them as fast as possible. Not sure if that typo is worth a pull
request? Do it! We will appreciate it.
If your pull request is not accepted on the first try, don't be
discouraged! Sometimes we can make a mistake, please do more explaining
for us. We will appreciate it.
We're trying very hard to keep beego simple and fast. We don't want it
to do everything for everybody. This means that we might decide against
incorporating a new feature. But we will give you some advice on how to
do it in other way.
### Create issues
Any significant improvement should be documented as [a GitHub
issue](https://github.com/astaxie/beego/issues) before anybody
starts working on it.
Also when filing an issue, make sure to answer these five questions:
- What version of beego are you using (bee version)?
- What operating system and processor architecture are you using?
- What did you do?
- What did you expect to see?
- What did you see instead?
### but check existing issues and docs first!
Please take a moment to check that an issue doesn't already exist
documenting your bug report or improvement proposal. If it does, it
never hurts to add a quick "+1" or "I have this problem too". This will
help prioritize the most common problems and requests.
Also if you don't know how to use it. please make sure you have read though
the docs in http://beego.me/docs

View File

@ -1,16 +1,38 @@
## Beego
[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest)
[![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego)
[![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego)
beego is an open-source, high-performance, modular, full-stack web framework.
beego is used for rapid development of RESTful APIs, web apps and backend services in Go.
It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding.
More info [beego.me](http://beego.me)
## Installation
##Quick Start
######Download and install
go get github.com/astaxie/beego
######Create file `hello.go`
```go
package main
import "github.com/astaxie/beego"
func main(){
beego.Run()
}
```
######Build and run
```bash
go build hello.go
./hello
```
######Congratulations!
You just built your first beego app.
Open your browser and visit `http://localhost:8000`.
Please see [Documentation](http://beego.me/docs) for more.
## Features
* RESTful support
@ -26,6 +48,7 @@ More info [beego.me](http://beego.me)
* [English](http://beego.me/docs/intro/)
* [中文文档](http://beego.me/docs/intro/)
* [Русский](http://beego.me/docs/intro/)
## Community
@ -33,5 +56,5 @@ More info [beego.me](http://beego.me)
## LICENSE
beego is licensed under the Apache Licence, Version 2.0
beego source code is licensed under the Apache Licence, Version 2.0
(http://www.apache.org/licenses/LICENSE-2.0.html).

324
admin.go
View File

@ -19,9 +19,11 @@ import (
"encoding/json"
"fmt"
"net/http"
"os"
"text/template"
"time"
"github.com/astaxie/beego/grace"
"github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils"
)
@ -63,24 +65,15 @@ func init() {
// AdminIndex is the default http.Handler for admin module.
// it matches url pattern "/".
func adminIndex(rw http.ResponseWriter, r *http.Request) {
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(indexTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
data := make(map[interface{}]interface{})
tmpl.Execute(rw, data)
execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl)
}
// QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter.
// it's registered with url pattern "/qbs" in admin module.
func qpsIndex(rw http.ResponseWriter, r *http.Request) {
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(qpsTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
data := make(map[interface{}]interface{})
data["Content"] = toolbox.StatisticsMap.GetMap()
tmpl.Execute(rw, data)
execTpl(rw, data, qpsTpl, defaultScriptsTpl)
}
// ListConf is the http.Handler of displaying all beego configuration values as key/value pair.
@ -88,49 +81,66 @@ func qpsIndex(rw http.ResponseWriter, r *http.Request) {
func listConf(rw http.ResponseWriter, r *http.Request) {
r.ParseForm()
command := r.Form.Get("command")
if command != "" {
if command == "" {
rw.Write([]byte("command not support"))
return
}
data := make(map[interface{}]interface{})
switch command {
case "conf":
m := make(map[string]interface{})
m["AppName"] = AppName
m["AppPath"] = AppPath
m["AppConfigPath"] = AppConfigPath
m["StaticDir"] = StaticDir
m["StaticExtensionsToGzip"] = StaticExtensionsToGzip
m["HttpAddr"] = HttpAddr
m["HttpPort"] = HttpPort
m["HttpTLS"] = EnableHttpTLS
m["HttpCertFile"] = HttpCertFile
m["HttpKeyFile"] = HttpKeyFile
m["RecoverPanic"] = RecoverPanic
m["AutoRender"] = AutoRender
m["ViewsPath"] = ViewsPath
m["RunMode"] = RunMode
m["SessionOn"] = SessionOn
m["SessionProvider"] = SessionProvider
m["SessionName"] = SessionName
m["SessionGCMaxLifetime"] = SessionGCMaxLifetime
m["SessionSavePath"] = SessionSavePath
m["SessionCookieLifeTime"] = SessionCookieLifeTime
m["UseFcgi"] = UseFcgi
m["MaxMemory"] = MaxMemory
m["EnableGzip"] = EnableGzip
m["DirectoryIndex"] = DirectoryIndex
m["HttpServerTimeOut"] = HttpServerTimeOut
m["ErrorsShow"] = ErrorsShow
m["XSRFKEY"] = XSRFKEY
m["EnableXSRF"] = EnableXSRF
m["XSRFExpire"] = XSRFExpire
m["CopyRequestBody"] = CopyRequestBody
m["TemplateLeft"] = TemplateLeft
m["TemplateRight"] = TemplateRight
m["BeegoServerName"] = BeegoServerName
m["EnableAdmin"] = EnableAdmin
m["AdminHttpAddr"] = AdminHttpAddr
m["AdminHttpPort"] = AdminHttpPort
m["AppConfigProvider"] = AppConfigProvider
m["BConfig.AppName"] = BConfig.AppName
m["BConfig.RunMode"] = BConfig.RunMode
m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
m["BConfig.ServerName"] = BConfig.ServerName
m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
m["BConfig.EnableGzip"] = BConfig.EnableGzip
m["BConfig.MaxMemory"] = BConfig.MaxMemory
m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
m["BConfig.WebConfig.XSRFKEY"] = BConfig.WebConfig.XSRFKey
m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(configTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
@ -140,17 +150,17 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
tmpl.Execute(rw, data)
case "router":
content := make(map[string]interface{})
var fields = []string{
fmt.Sprintf("Router Pattern"),
fmt.Sprintf("Methods"),
fmt.Sprintf("Controller"),
var (
content = map[string]interface{}{
"Fields": []string{
"Router Pattern",
"Methods",
"Controller",
},
}
content["Fields"] = fields
methods := []string{}
methodsData := make(map[string]interface{})
methods = []string{}
methodsData = make(map[string]interface{})
)
for method, t := range BeeApp.Handlers.routers {
resultList := new([][]string)
@ -165,32 +175,32 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
content["Methods"] = methods
data["Content"] = content
data["Title"] = "Routers"
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(routerAndFilterTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
tmpl.Execute(rw, data)
execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
case "filter":
content := make(map[string]interface{})
var fields = []string{
fmt.Sprintf("Router Pattern"),
fmt.Sprintf("Filter Function"),
var (
content = map[string]interface{}{
"Fields": []string{
"Router Pattern",
"Filter Function",
},
}
content["Fields"] = fields
filterTypes := []string{}
filterTypeData := make(map[string]interface{})
filterTypes = []string{}
filterTypeData = make(map[string]interface{})
)
if BeeApp.Handlers.enableFilter {
var filterType string
if bf, ok := BeeApp.Handlers.filters[BeforeRouter]; ok {
filterType = "Before Router"
for k, fr := range map[int]string{
BeforeStatic: "Before Static",
BeforeRouter: "Before Router",
BeforeExec: "Before Exec",
AfterExec: "After Exec",
FinishRouter: "Finish Router"} {
if bf, ok := BeeApp.Handlers.filters[k]; ok {
filterType = fr
filterTypes = append(filterTypes, filterType)
resultList := new([][]string)
for _, f := range bf {
var result = []string{
fmt.Sprintf("%s", f.pattern),
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
@ -199,50 +209,6 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
}
filterTypeData[filterType] = resultList
}
if bf, ok := BeeApp.Handlers.filters[BeforeExec]; ok {
filterType = "Before Exec"
filterTypes = append(filterTypes, filterType)
resultList := new([][]string)
for _, f := range bf {
var result = []string{
fmt.Sprintf("%s", f.pattern),
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
}
*resultList = append(*resultList, result)
}
filterTypeData[filterType] = resultList
}
if bf, ok := BeeApp.Handlers.filters[AfterExec]; ok {
filterType = "After Exec"
filterTypes = append(filterTypes, filterType)
resultList := new([][]string)
for _, f := range bf {
var result = []string{
fmt.Sprintf("%s", f.pattern),
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
}
*resultList = append(*resultList, result)
}
filterTypeData[filterType] = resultList
}
if bf, ok := BeeApp.Handlers.filters[FinishRouter]; ok {
filterType = "Finish Router"
filterTypes = append(filterTypes, filterType)
resultList := new([][]string)
for _, f := range bf {
var result = []string{
fmt.Sprintf("%s", f.pattern),
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
}
*resultList = append(*resultList, result)
}
filterTypeData[filterType] = resultList
}
}
@ -251,16 +217,10 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
data["Content"] = content
data["Title"] = "Filters"
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(routerAndFilterTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
tmpl.Execute(rw, data)
execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
default:
rw.Write([]byte("command not support"))
}
} else {
}
}
func printTree(resultList *[][]string, t *Tree) {
@ -274,23 +234,23 @@ func printTree(resultList *[][]string, t *Tree) {
if v, ok := l.runObject.(*controllerInfo); ok {
if v.routerType == routerTypeBeego {
var result = []string{
fmt.Sprintf("%s", v.pattern),
v.pattern,
fmt.Sprintf("%s", v.methods),
fmt.Sprintf("%s", v.controllerType),
}
*resultList = append(*resultList, result)
} else if v.routerType == routerTypeRESTFul {
var result = []string{
fmt.Sprintf("%s", v.pattern),
v.pattern,
fmt.Sprintf("%s", v.methods),
fmt.Sprintf(""),
"",
}
*resultList = append(*resultList, result)
} else if v.routerType == routerTypeHandler {
var result = []string{
fmt.Sprintf("%s", v.pattern),
fmt.Sprintf(""),
fmt.Sprintf(""),
v.pattern,
"",
"",
}
*resultList = append(*resultList, result)
}
@ -303,54 +263,49 @@ func printTree(resultList *[][]string, t *Tree) {
func profIndex(rw http.ResponseWriter, r *http.Request) {
r.ParseForm()
command := r.Form.Get("command")
format := r.Form.Get("format")
data := make(map[string]interface{})
if command == "" {
return
}
var result bytes.Buffer
if command != "" {
var (
format = r.Form.Get("format")
data = make(map[interface{}]interface{})
result bytes.Buffer
)
toolbox.ProcessInput(command, &result)
data["Content"] = result.String()
if format == "json" && command == "gc summary" {
dataJson, err := json.Marshal(data)
dataJSON, err := json.Marshal(data)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
rw.Header().Set("Content-Type", "application/json")
rw.Write(dataJson)
rw.Write(dataJSON)
return
}
data["Title"] = command
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(profillingTpl))
defaultTpl := defaultScriptsTpl
if command == "gc summary" {
tmpl = template.Must(tmpl.Parse(gcAjaxTpl))
} else {
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
}
tmpl.Execute(rw, data)
} else {
defaultTpl = gcAjaxTpl
}
execTpl(rw, data, profillingTpl, defaultTpl)
}
// Healthcheck is a http.Handler calling health checking and showing the result.
// it's in "/healthcheck" pattern in admin module.
func healthcheck(rw http.ResponseWriter, req *http.Request) {
data := make(map[interface{}]interface{})
var result = []string{}
fields := []string{
fmt.Sprintf("Name"),
fmt.Sprintf("Message"),
fmt.Sprintf("Status"),
var (
data = make(map[interface{}]interface{})
result = []string{}
resultList = new([][]string)
content = map[string]interface{}{
"Fields": []string{"Name", "Message", "Status"},
}
resultList := new([][]string)
content := make(map[string]interface{})
)
for name, h := range toolbox.AdminCheckList {
if err := h.Check(); err != nil {
@ -370,16 +325,10 @@ func healthcheck(rw http.ResponseWriter, req *http.Request) {
}
*resultList = append(*resultList, result)
}
content["Fields"] = fields
content["Data"] = resultList
data["Content"] = content
data["Title"] = "Health Check"
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(healthCheckTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
tmpl.Execute(rw, data)
execTpl(rw, data, healthCheckTpl, defaultScriptsTpl)
}
// TaskStatus is a http.Handler with running task status (task name, status and the last execution).
@ -391,10 +340,8 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
req.ParseForm()
taskname := req.Form.Get("taskname")
if taskname != "" {
if t, ok := toolbox.AdminTaskList[taskname]; ok {
err := t.Run()
if err != nil {
if err := t.Run(); err != nil {
data["Message"] = []string{"error", fmt.Sprintf("%s", err)}
}
data["Message"] = []string{"success", fmt.Sprintf("%s run success,Now the Status is <br>%s", taskname, t.GetStatus())}
@ -408,18 +355,18 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
resultList := new([][]string)
var result = []string{}
var fields = []string{
fmt.Sprintf("Task Name"),
fmt.Sprintf("Task Spec"),
fmt.Sprintf("Task Status"),
fmt.Sprintf("Last Time"),
fmt.Sprintf(""),
"Task Name",
"Task Spec",
"Task Status",
"Last Time",
"",
}
for tname, tk := range toolbox.AdminTaskList {
result = []string{
fmt.Sprintf("%s", tname),
tname,
fmt.Sprintf("%s", tk.GetSpec()),
fmt.Sprintf("%s", tk.GetStatus()),
fmt.Sprintf("%s", tk.GetPrev().String()),
tk.GetPrev().String(),
}
*resultList = append(*resultList, result)
}
@ -428,9 +375,14 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
content["Data"] = resultList
data["Content"] = content
data["Title"] = "Tasks"
execTpl(rw, data, tasksTpl, defaultScriptsTpl)
}
func execTpl(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) {
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(tasksTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
for _, tpl := range tpls {
tmpl = template.Must(tmpl.Parse(tpl))
}
tmpl.Execute(rw, data)
}
@ -450,17 +402,23 @@ func (admin *adminApp) Run() {
if len(toolbox.AdminTaskList) > 0 {
toolbox.StartTask()
}
addr := AdminHttpAddr
addr := BConfig.Listen.AdminAddr
if AdminHttpPort != 0 {
addr = fmt.Sprintf("%s:%d", AdminHttpAddr, AdminHttpPort)
if BConfig.Listen.AdminPort != 0 {
addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort)
}
for p, f := range admin.routers {
http.Handle(p, f)
}
BeeLogger.Info("Admin server Running on %s", addr)
err := http.ListenAndServe(addr, nil)
var err error
if BConfig.Listen.Graceful {
err = grace.ListenAndServe(addr, nil)
} else {
err = http.ListenAndServe(addr, nil)
}
if err != nil {
BeeLogger.Critical("Admin ListenAndServe: ", err)
BeeLogger.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
}
}

294
app.go
View File

@ -20,14 +20,26 @@ import (
"net/http"
"net/http/fcgi"
"os"
"path"
"time"
"github.com/astaxie/beego/grace"
"github.com/astaxie/beego/utils"
)
var (
// BeeApp is an application instance
BeeApp *App
)
func init() {
// create beego application
BeeApp = NewApp()
}
// App defines beego application with a new PatternServeMux.
type App struct {
Handlers *ControllerRegistor
Handlers *ControllerRegister
Server *http.Server
}
@ -40,28 +52,29 @@ func NewApp() *App {
// Run beego application.
func (app *App) Run() {
addr := HttpAddr
addr := BConfig.Listen.HTTPAddr
if HttpPort != 0 {
addr = fmt.Sprintf("%s:%d", HttpAddr, HttpPort)
if BConfig.Listen.HTTPPort != 0 {
addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort)
}
var (
err error
l net.Listener
endRunning = make(chan bool, 1)
)
endRunning := make(chan bool, 1)
if UseFcgi {
if UseStdIo {
err = fcgi.Serve(nil, app.Handlers) // standard I/O
if err == nil {
// run cgi server
if BConfig.Listen.EnableFcgi {
if BConfig.Listen.EnableStdIo {
if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
BeeLogger.Info("Use FCGI via standard I/O")
} else {
BeeLogger.Info("Cannot use FCGI via standard I/O", err)
BeeLogger.Critical("Cannot use FCGI via standard I/O", err)
}
} else {
if HttpPort == 0 {
return
}
if BConfig.Listen.HTTPPort == 0 {
// remove the Socket file before start
if utils.FileExists(addr) {
os.Remove(addr)
@ -73,35 +86,77 @@ func (app *App) Run() {
if err != nil {
BeeLogger.Critical("Listen: ", err)
}
err = fcgi.Serve(l, app.Handlers)
if err = fcgi.Serve(l, app.Handlers); err != nil {
BeeLogger.Critical("fcgi.Serve: ", err)
}
return
}
} else {
app.Server.Addr = addr
app.Server.Handler = app.Handlers
app.Server.ReadTimeout = time.Duration(HttpServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(HttpServerTimeOut) * time.Second
if EnableHttpTLS {
app.Server.Handler = app.Handlers
app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
// run graceful mode
if BConfig.Listen.Graceful {
httpsAddr := BConfig.Listen.HTTPSAddr
app.Server.Addr = httpsAddr
if BConfig.Listen.EnableHTTPS {
go func() {
time.Sleep(20 * time.Microsecond)
if HttpsPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", HttpAddr, HttpsPort)
if BConfig.Listen.HTTPSPort != 0 {
httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
app.Server.Addr = httpsAddr
}
server := grace.NewServer(httpsAddr, app.Handlers)
server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
if BConfig.Listen.EnableHTTP {
go func() {
server := grace.NewServer(addr, app.Handlers)
server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout
if BConfig.Listen.ListenTCP4 {
server.Network = "tcp4"
}
if err := server.ListenAndServe(); err != nil {
BeeLogger.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
<-endRunning
return
}
// run normal mode
app.Server.Addr = addr
if BConfig.Listen.EnableHTTPS {
go func() {
time.Sleep(20 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
}
BeeLogger.Info("https server Running on %s", app.Server.Addr)
err := app.Server.ListenAndServeTLS(HttpCertFile, HttpKeyFile)
if err != nil {
if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
BeeLogger.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
if EnableHttpListen {
if BConfig.Listen.EnableHTTP {
go func() {
app.Server.Addr = addr
BeeLogger.Info("http server Running on %s", app.Server.Addr)
if ListenTCP4 && HttpAddr == "" {
if BConfig.Listen.ListenTCP4 {
ln, err := net.Listen("tcp4", app.Server.Addr)
if err != nil {
BeeLogger.Critical("ListenAndServe: ", err)
@ -109,16 +164,14 @@ func (app *App) Run() {
endRunning <- true
return
}
err = app.Server.Serve(ln)
if err != nil {
if err = app.Server.Serve(ln); err != nil {
BeeLogger.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return
}
} else {
err := app.Server.ListenAndServe()
if err != nil {
if err := app.Server.ListenAndServe(); err != nil {
BeeLogger.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
@ -126,7 +179,184 @@ func (app *App) Run() {
}
}()
}
}
<-endRunning
}
// Router adds a patterned controller handler to BeeApp.
// it's an alias method of App.Router.
// usage:
// simple router
// beego.Router("/admin", &admin.UserController{})
// beego.Router("/admin/index", &admin.ArticleController{})
//
// regex router
//
// beego.Router("/api/:id([0-9]+)", &controllers.RController{})
//
// custom rules
// beego.Router("/api/list",&RestController{},"*:ListFood")
// beego.Router("/api/create",&RestController{},"post:CreateFood")
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
BeeApp.Handlers.Add(rootpath, c, mappingMethods...)
return BeeApp
}
// Include will generate router file in the router/xxx.go from the controller's comments
// usage:
// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
// type BankAccount struct{
// beego.Controller
// }
//
// register the function
// func (b *BankAccount)Mapping(){
// b.Mapping("ShowAccount" , b.ShowAccount)
// b.Mapping("ModifyAccount", b.ModifyAccount)
//}
//
// //@router /account/:id [get]
// func (b *BankAccount) ShowAccount(){
// //logic
// }
//
//
// //@router /account/:id [post]
// func (b *BankAccount) ModifyAccount(){
// //logic
// }
//
// the comments @router url methodlist
// url support all the function Router's pattern
// methodlist [get post head put delete options *]
func Include(cList ...ControllerInterface) *App {
BeeApp.Handlers.Include(cList...)
return BeeApp
}
// RESTRouter adds a restful controller handler to BeeApp.
// its' controller implements beego.ControllerInterface and
// defines a param "pattern/:objectId" to visit each resource.
func RESTRouter(rootpath string, c ControllerInterface) *App {
Router(rootpath, c)
Router(path.Join(rootpath, ":objectId"), c)
return BeeApp
}
// AutoRouter adds defined controller handler to BeeApp.
// it's same to App.AutoRouter.
// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
// visit the url /main/list to exec List function or /main/page to exec Page function.
func AutoRouter(c ControllerInterface) *App {
BeeApp.Handlers.AddAuto(c)
return BeeApp
}
// AutoPrefix adds controller handler to BeeApp with prefix.
// it's same to App.AutoRouterWithPrefix.
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
func AutoPrefix(prefix string, c ControllerInterface) *App {
BeeApp.Handlers.AddAutoPrefix(prefix, c)
return BeeApp
}
// Get used to register router for Get method
// usage:
// beego.Get("/", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Get(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Get(rootpath, f)
return BeeApp
}
// Post used to register router for Post method
// usage:
// beego.Post("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Post(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Post(rootpath, f)
return BeeApp
}
// Delete used to register router for Delete method
// usage:
// beego.Delete("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Delete(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Delete(rootpath, f)
return BeeApp
}
// Put used to register router for Put method
// usage:
// beego.Put("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Put(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Put(rootpath, f)
return BeeApp
}
// Head used to register router for Head method
// usage:
// beego.Head("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Head(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Head(rootpath, f)
return BeeApp
}
// Options used to register router for Options method
// usage:
// beego.Options("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Options(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Options(rootpath, f)
return BeeApp
}
// Patch used to register router for Patch method
// usage:
// beego.Patch("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Patch(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Patch(rootpath, f)
return BeeApp
}
// Any used to register router for all methods
// usage:
// beego.Any("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Any(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Any(rootpath, f)
return BeeApp
}
// Handler used to register a Handler router
// usage:
// beego.Handler("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
BeeApp.Handlers.Handler(rootpath, h, options...)
return BeeApp
}
// InsertFilter adds a FilterFunc with pattern condition and action constant.
// The pos means action constant including
// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App {
BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
return BeeApp
}

337
beego.go
View File

@ -12,243 +12,35 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// beego is an open-source, high-performance, modularity, full-stack web framework
//
// package main
//
// import "github.com/astaxie/beego"
//
// func main() {
// beego.Run()
// }
//
// more infomation: http://beego.me
package beego
import (
"net/http"
"fmt"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"github.com/astaxie/beego/session"
)
// beego web framework version.
const VERSION = "1.4.3"
const (
// VERSION represent beego web framework version.
VERSION = "1.6.0"
type hookfunc func() error //hook function to run
var hooks []hookfunc //hook function slice to store the hookfunc
// DEV is for develop
DEV = "dev"
// PROD is for production
PROD = "prod"
)
// Router adds a patterned controller handler to BeeApp.
// it's an alias method of App.Router.
// usage:
// simple router
// beego.Router("/admin", &admin.UserController{})
// beego.Router("/admin/index", &admin.ArticleController{})
//
// regex router
//
// beego.Router("/api/:id([0-9]+)", &controllers.RController{})
//
// custom rules
// beego.Router("/api/list",&RestController{},"*:ListFood")
// beego.Router("/api/create",&RestController{},"post:CreateFood")
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
BeeApp.Handlers.Add(rootpath, c, mappingMethods...)
return BeeApp
}
//hook function to run
type hookfunc func() error
// Router add list from
// usage:
// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
// type BankAccount struct{
// beego.Controller
// }
//
// register the function
// func (b *BankAccount)Mapping(){
// b.Mapping("ShowAccount" , b.ShowAccount)
// b.Mapping("ModifyAccount", b.ModifyAccount)
//}
//
// //@router /account/:id [get]
// func (b *BankAccount) ShowAccount(){
// //logic
// }
//
//
// //@router /account/:id [post]
// func (b *BankAccount) ModifyAccount(){
// //logic
// }
//
// the comments @router url methodlist
// url support all the function Router's pattern
// methodlist [get post head put delete options *]
func Include(cList ...ControllerInterface) *App {
BeeApp.Handlers.Include(cList...)
return BeeApp
}
var (
hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc
)
// RESTRouter adds a restful controller handler to BeeApp.
// its' controller implements beego.ControllerInterface and
// defines a param "pattern/:objectId" to visit each resource.
func RESTRouter(rootpath string, c ControllerInterface) *App {
Router(rootpath, c)
Router(path.Join(rootpath, ":objectId"), c)
return BeeApp
}
// AutoRouter adds defined controller handler to BeeApp.
// it's same to App.AutoRouter.
// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
// visit the url /main/list to exec List function or /main/page to exec Page function.
func AutoRouter(c ControllerInterface) *App {
BeeApp.Handlers.AddAuto(c)
return BeeApp
}
// AutoPrefix adds controller handler to BeeApp with prefix.
// it's same to App.AutoRouterWithPrefix.
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
func AutoPrefix(prefix string, c ControllerInterface) *App {
BeeApp.Handlers.AddAutoPrefix(prefix, c)
return BeeApp
}
// register router for Get method
// usage:
// beego.Get("/", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Get(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Get(rootpath, f)
return BeeApp
}
// register router for Post method
// usage:
// beego.Post("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Post(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Post(rootpath, f)
return BeeApp
}
// register router for Delete method
// usage:
// beego.Delete("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Delete(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Delete(rootpath, f)
return BeeApp
}
// register router for Put method
// usage:
// beego.Put("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Put(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Put(rootpath, f)
return BeeApp
}
// register router for Head method
// usage:
// beego.Head("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Head(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Head(rootpath, f)
return BeeApp
}
// register router for Options method
// usage:
// beego.Options("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Options(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Options(rootpath, f)
return BeeApp
}
// register router for Patch method
// usage:
// beego.Patch("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Patch(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Patch(rootpath, f)
return BeeApp
}
// register router for all method
// usage:
// beego.Any("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Any(rootpath string, f FilterFunc) *App {
BeeApp.Handlers.Any(rootpath, f)
return BeeApp
}
// register router for own Handler
// usage:
// beego.Handler("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
BeeApp.Handlers.Handler(rootpath, h, options...)
return BeeApp
}
// SetViewsPath sets view directory path in beego application.
func SetViewsPath(path string) *App {
ViewsPath = path
return BeeApp
}
// SetStaticPath sets static directory path and proper url pattern in beego application.
// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public".
func SetStaticPath(url string, path string) *App {
if !strings.HasPrefix(url, "/") {
url = "/" + url
}
url = strings.TrimRight(url, "/")
StaticDir[url] = path
return BeeApp
}
// DelStaticPath removes the static folder setting in this url pattern in beego application.
func DelStaticPath(url string) *App {
if !strings.HasPrefix(url, "/") {
url = "/" + url
}
url = strings.TrimRight(url, "/")
delete(StaticDir, url)
return BeeApp
}
// InsertFilter adds a FilterFunc with pattern condition and action constant.
// The pos means action constant including
// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App {
BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...)
return BeeApp
}
// The hookfunc will run in beego.Run()
// AddAPPStartHook is used to register the hookfunc
// The hookfuncs will run in beego.Run()
// such as sessionInit, middlerware start, buildtemplate, admin start
func AddAPPStartHook(hf hookfunc) {
hooks = append(hooks, hf)
@ -256,97 +48,60 @@ func AddAPPStartHook(hf hookfunc) {
// Run beego application.
// beego.Run() default run on HttpPort
// beego.Run("localhost")
// beego.Run(":8089")
// beego.Run("127.0.0.1:8089")
func Run(params ...string) {
initBeforeHTTPRun()
if len(params) > 0 && params[0] != "" {
strs := strings.Split(params[0], ":")
if len(strs) > 0 && strs[0] != "" {
HttpAddr = strs[0]
BConfig.Listen.HTTPAddr = strs[0]
}
if len(strs) > 1 && strs[1] != "" {
HttpPort, _ = strconv.Atoi(strs[1])
BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1])
}
}
initBeforeHttpRun()
if EnableAdmin {
go beeAdminApp.Run()
}
BeeApp.Run()
}
func initBeforeHttpRun() {
// if AppConfigPath not In the conf/app.conf reParse config
if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") {
func initBeforeHTTPRun() {
// if AppConfigPath is setted or conf/app.conf exist
err := ParseConfig()
if err != nil && AppConfigPath != filepath.Join(workPath, "conf", "app.conf") {
// configuration is critical to app, panic here if parse failed
if err != nil {
panic(err)
}
//init log
for adaptor, config := range BConfig.Log.Outputs {
err = BeeLogger.SetLogger(adaptor, config)
if err != nil {
fmt.Printf("%s with the config `%s` got err:%s\n", adaptor, config, err)
}
}
//init mime
AddAPPStartHook(initMime)
SetLogFuncCall(BConfig.Log.FileLineNum)
//init hooks
AddAPPStartHook(registerMime)
AddAPPStartHook(registerDefaultErrorHandler)
AddAPPStartHook(registerSession)
AddAPPStartHook(registerDocs)
AddAPPStartHook(registerTemplate)
AddAPPStartHook(registerAdmin)
// do hooks function
for _, hk := range hooks {
err := hk()
if err != nil {
if err := hk(); err != nil {
panic(err)
}
}
if SessionOn {
var err error
sessionConfig := AppConfig.String("sessionConfig")
if sessionConfig == "" {
sessionConfig = `{"cookieName":"` + SessionName + `",` +
`"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
`"providerConfig":"` + filepath.ToSlash(SessionSavePath) + `",` +
`"secure":` + strconv.FormatBool(EnableHttpTLS) + `,` +
`"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
`"domain":"` + SessionDomain + `",` +
`"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
}
GlobalSessions, err = session.NewManager(SessionProvider,
sessionConfig)
if err != nil {
panic(err)
}
go GlobalSessions.GC()
}
err := BuildTemplate(ViewsPath)
if err != nil {
if RunMode == "dev" {
Warn(err)
}
}
registerDefaultErrorHandler()
if EnableDocs {
Get("/docs", serverDocs)
Get("/docs/*", serverDocs)
}
}
// this function is for test package init
func TestBeegoInit(apppath string) {
AppPath = apppath
RunMode = "test"
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
err := ParseConfig()
if err != nil && !os.IsNotExist(err) {
// for init if doesn't have app.conf will not panic
Info(err)
}
os.Chdir(AppPath)
initBeforeHttpRun()
}
func init() {
hooks = make([]hookfunc, 0)
// TestBeegoInit is for test package init
func TestBeegoInit(ap string) {
os.Setenv("BEEGO_RUNMODE", "test")
AppConfigPath = filepath.Join(ap, "conf", "app.conf")
os.Chdir(ap)
initBeforeHTTPRun()
}

6
cache/README.md vendored
View File

@ -26,7 +26,7 @@ Then init a Cache (example with memory adapter)
Use it like this:
bm.Put("astaxie", 1, 10)
bm.Put("astaxie", 1, 10 * time.Second)
bm.Get("astaxie")
bm.IsExist("astaxie")
bm.Delete("astaxie")
@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
## Memcache adapter
Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client.
Configure like this:
@ -52,7 +52,7 @@ Configure like this:
## Redis adapter
Redis adapter use the [redigo](http://github.com/garyburd/redigo/redis) client.
Redis adapter use the [redigo](http://github.com/garyburd/redigo) client.
Configure like this:

24
cache/cache.go vendored
View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Package cache provide a Cache interface and some implemetn engine
// Usage:
//
// import(
@ -22,7 +23,7 @@
//
// Use it like this:
//
// bm.Put("astaxie", 1, 10)
// bm.Put("astaxie", 1, 10 * time.Second)
// bm.Get("astaxie")
// bm.IsExist("astaxie")
// bm.Delete("astaxie")
@ -32,13 +33,14 @@ package cache
import (
"fmt"
"time"
)
// Cache interface contains all behaviors for cache adapter.
// usage:
// cache.Register("file",cache.NewFileCache()) // this operation is run in init method of file.go.
// cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go.
// c,err := cache.NewCache("file","{....}")
// c.Put("key",value,3600)
// c.Put("key",value, 3600 * time.Second)
// v := c.Get("key")
//
// c.Incr("counter") // now is 1
@ -47,8 +49,10 @@ import (
type Cache interface {
// get cached value by key.
Get(key string) interface{}
// GetMulti is a batch version of Get.
GetMulti(keys []string) []interface{}
// set cached value with key and expire time.
Put(key string, val interface{}, timeout int64) error
Put(key string, val interface{}, timeout time.Duration) error
// delete cached value by key.
Delete(key string) error
// increase cached int value by key, as a counter.
@ -63,12 +67,15 @@ type Cache interface {
StartAndGC(config string) error
}
var adapters = make(map[string]Cache)
// Instance is a function create a new Cache Instance
type Instance func() Cache
var adapters = make(map[string]Instance)
// Register makes a cache adapter available by the adapter name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
func Register(name string, adapter Cache) {
func Register(name string, adapter Instance) {
if adapter == nil {
panic("cache: Register adapter is nil")
}
@ -78,15 +85,16 @@ func Register(name string, adapter Cache) {
adapters[name] = adapter
}
// Create a new cache driver by adapter name and config string.
// NewCache Create a new cache driver by adapter name and config string.
// config need to be correct JSON as string: {"interval":360}.
// it will start gc automatically.
func NewCache(adapterName, config string) (adapter Cache, err error) {
adapter, ok := adapters[adapterName]
instanceFunc, ok := adapters[adapterName]
if !ok {
err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName)
return
}
adapter = instanceFunc()
err = adapter.StartAndGC(config)
if err != nil {
adapter = nil

61
cache/cache_test.go vendored
View File

@ -25,7 +25,8 @@ func TestCache(t *testing.T) {
if err != nil {
t.Error("init err")
}
if err = bm.Put("astaxie", 1, 10); err != nil {
timeoutDuration := 10 * time.Second
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@ -42,7 +43,7 @@ func TestCache(t *testing.T) {
t.Error("check err")
}
if err = bm.Put("astaxie", 1, 10); err != nil {
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
@ -65,6 +66,35 @@ func TestCache(t *testing.T) {
if bm.IsExist("astaxie") {
t.Error("delete err")
}
//test GetMulti
if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(string) != "author" {
t.Error("get err")
}
if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie1") {
t.Error("check err")
}
vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
if len(vv) != 2 {
t.Error("GetMulti ERROR")
}
if vv[0].(string) != "author" {
t.Error("GetMulti ERROR")
}
if vv[1].(string) != "author1" {
t.Error("GetMulti ERROR")
}
}
func TestFileCache(t *testing.T) {
@ -72,7 +102,8 @@ func TestFileCache(t *testing.T) {
if err != nil {
t.Error("init err")
}
if err = bm.Put("astaxie", 1, 10); err != nil {
timeoutDuration := 10 * time.Second
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@ -102,16 +133,36 @@ func TestFileCache(t *testing.T) {
if bm.IsExist("astaxie") {
t.Error("delete err")
}
//test string
if err = bm.Put("astaxie", "author", 10); err != nil {
if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(string) != "author" {
t.Error("get err")
}
//test GetMulti
if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie1") {
t.Error("check err")
}
vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
if len(vv) != 2 {
t.Error("GetMulti ERROR")
}
if vv[0].(string) != "author" {
t.Error("GetMulti ERROR")
}
if vv[1].(string) != "author1" {
t.Error("GetMulti ERROR")
}
os.RemoveAll("cache")
}

22
cache/conv.go vendored
View File

@ -19,7 +19,7 @@ import (
"strconv"
)
// convert interface to string.
// GetString convert interface to string.
func GetString(v interface{}) string {
switch result := v.(type) {
case string:
@ -34,7 +34,7 @@ func GetString(v interface{}) string {
return ""
}
// convert interface to int.
// GetInt convert interface to int.
func GetInt(v interface{}) int {
switch result := v.(type) {
case int:
@ -52,7 +52,7 @@ func GetInt(v interface{}) int {
return 0
}
// convert interface to int64.
// GetInt64 convert interface to int64.
func GetInt64(v interface{}) int64 {
switch result := v.(type) {
case int:
@ -71,7 +71,7 @@ func GetInt64(v interface{}) int64 {
return 0
}
// convert interface to float64.
// GetFloat64 convert interface to float64.
func GetFloat64(v interface{}) float64 {
switch result := v.(type) {
case float64:
@ -85,7 +85,7 @@ func GetFloat64(v interface{}) float64 {
return 0
}
// convert interface to bool.
// GetBool convert interface to bool.
func GetBool(v interface{}) bool {
switch result := v.(type) {
case bool:
@ -98,15 +98,3 @@ func GetBool(v interface{}) bool {
}
return false
}
// convert interface to byte slice.
func getByteArray(v interface{}) []byte {
switch result := v.(type) {
case []byte:
return result
case string:
return []byte(result)
default:
return nil
}
}

29
cache/conv_test.go vendored
View File

@ -27,7 +27,7 @@ func TestGetString(t *testing.T) {
if "test2" != GetString(t2) {
t.Error("get string from byte array error")
}
var t3 int = 1
var t3 = 1
if "1" != GetString(t3) {
t.Error("get string from int error")
}
@ -35,7 +35,7 @@ func TestGetString(t *testing.T) {
if "1" != GetString(t4) {
t.Error("get string from int64 error")
}
var t5 float64 = 1.1
var t5 = 1.1
if "1.1" != GetString(t5) {
t.Error("get string from float64 error")
}
@ -46,7 +46,7 @@ func TestGetString(t *testing.T) {
}
func TestGetInt(t *testing.T) {
var t1 int = 1
var t1 = 1
if 1 != GetInt(t1) {
t.Error("get int from int error")
}
@ -69,7 +69,7 @@ func TestGetInt(t *testing.T) {
func TestGetInt64(t *testing.T) {
var i int64 = 1
var t1 int = 1
var t1 = 1
if i != GetInt64(t1) {
t.Error("get int64 from int error")
}
@ -91,12 +91,12 @@ func TestGetInt64(t *testing.T) {
}
func TestGetFloat64(t *testing.T) {
var f float64 = 1.11
var f = 1.11
var t1 float32 = 1.11
if f != GetFloat64(t1) {
t.Error("get float64 from float32 error")
}
var t2 float64 = 1.11
var t2 = 1.11
if f != GetFloat64(t2) {
t.Error("get float64 from float64 error")
}
@ -106,7 +106,7 @@ func TestGetFloat64(t *testing.T) {
}
var f2 float64 = 1
var t4 int = 1
var t4 = 1
if f2 != GetFloat64(t4) {
t.Error("get float64 from int error")
}
@ -130,21 +130,6 @@ func TestGetBool(t *testing.T) {
}
}
func TestGetByteArray(t *testing.T) {
var b = []byte("test")
var t1 = []byte("test")
if !byteArrayEquals(b, getByteArray(t1)) {
t.Error("get byte array from byte array error")
}
var t2 = "test"
if !byteArrayEquals(b, getByteArray(t2)) {
t.Error("get byte array from string error")
}
if nil != getByteArray(nil) {
t.Error("get byte array from nil error")
}
}
func byteArrayEquals(a []byte, b []byte) bool {
if len(a) != len(b) {
return false

81
cache/file.go vendored
View File

@ -29,23 +29,20 @@ import (
"time"
)
func init() {
Register("file", NewFileCache())
}
// FileCacheItem is basic unit of file cache adapter.
// it contains data and expire time.
type FileCacheItem struct {
Data interface{}
Lastaccess int64
Expired int64
Lastaccess time.Time
Expired time.Time
}
// FileCache Config
var (
FileCachePath string = "cache" // cache directory
FileCacheFileSuffix string = ".bin" // cache file suffix
FileCacheDirectoryLevel int = 2 // cache file deep level if auto generated cache files.
FileCacheEmbedExpiry int64 = 0 // cache expire time, default is no expire forever.
FileCachePath = "cache" // cache directory
FileCacheFileSuffix = ".bin" // cache file suffix
FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files.
FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever.
)
// FileCache is cache adapter for file storage.
@ -56,14 +53,14 @@ type FileCache struct {
EmbedExpiry int
}
// Create new file cache with no config.
// NewFileCache Create new file cache with no config.
// the level and expiry need set in method StartAndGC as config string.
func NewFileCache() *FileCache {
func NewFileCache() Cache {
// return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix}
return &FileCache{}
}
// Start and begin gc for file cache.
// StartAndGC will start and begin gc for file cache.
// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}
func (fc *FileCache) StartAndGC(config string) error {
@ -79,7 +76,7 @@ func (fc *FileCache) StartAndGC(config string) error {
cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel)
}
if _, ok := cfg["EmbedExpiry"]; !ok {
cfg["EmbedExpiry"] = strconv.FormatInt(FileCacheEmbedExpiry, 10)
cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10)
}
fc.CachePath = cfg["CachePath"]
fc.FileSuffix = cfg["FileSuffix"]
@ -120,36 +117,46 @@ func (fc *FileCache) getCacheFileName(key string) string {
// Get value from file cache.
// if non-exist or expired, return empty string.
func (fc *FileCache) Get(key string) interface{} {
fileData, err := File_get_contents(fc.getCacheFileName(key))
fileData, err := FileGetContents(fc.getCacheFileName(key))
if err != nil {
return ""
}
var to FileCacheItem
Gob_decode(fileData, &to)
if to.Expired < time.Now().Unix() {
GobDecode(fileData, &to)
if to.Expired.Before(time.Now()) {
return ""
}
return to.Data
}
// GetMulti gets values from file cache.
// if non-exist or expired, return empty string.
func (fc *FileCache) GetMulti(keys []string) []interface{} {
var rc []interface{}
for _, key := range keys {
rc = append(rc, fc.Get(key))
}
return rc
}
// Put value into file cache.
// timeout means how long to keep this file, unit of ms.
// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever.
func (fc *FileCache) Put(key string, val interface{}, timeout int64) error {
func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error {
gob.Register(val)
item := FileCacheItem{Data: val}
if timeout == FileCacheEmbedExpiry {
item.Expired = time.Now().Unix() + (86400 * 365 * 10) // ten years
item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years
} else {
item.Expired = time.Now().Unix() + timeout
item.Expired = time.Now().Add(timeout)
}
item.Lastaccess = time.Now().Unix()
data, err := Gob_encode(item)
item.Lastaccess = time.Now()
data, err := GobEncode(item)
if err != nil {
return err
}
return File_put_contents(fc.getCacheFileName(key), data)
return FilePutContents(fc.getCacheFileName(key), data)
}
// Delete file cache value.
@ -161,7 +168,7 @@ func (fc *FileCache) Delete(key string) error {
return nil
}
// Increase cached int value.
// Incr will increase cached int value.
// fc value is saving forever unless Delete.
func (fc *FileCache) Incr(key string) error {
data := fc.Get(key)
@ -175,7 +182,7 @@ func (fc *FileCache) Incr(key string) error {
return nil
}
// Decrease cached int value.
// Decr will decrease cached int value.
func (fc *FileCache) Decr(key string) error {
data := fc.Get(key)
var decr int
@ -188,13 +195,13 @@ func (fc *FileCache) Decr(key string) error {
return nil
}
// Check value is exist.
// IsExist check value is exist.
func (fc *FileCache) IsExist(key string) bool {
ret, _ := exists(fc.getCacheFileName(key))
return ret
}
// Clean cached files.
// ClearAll will clean cached files.
// not implemented.
func (fc *FileCache) ClearAll() error {
return nil
@ -212,9 +219,9 @@ func exists(path string) (bool, error) {
return false, err
}
// Get bytes to file.
// FileGetContents Get bytes to file.
// if non-exist, create this file.
func File_get_contents(filename string) (data []byte, e error) {
func FileGetContents(filename string) (data []byte, e error) {
f, e := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm)
if e != nil {
return
@ -232,9 +239,9 @@ func File_get_contents(filename string) (data []byte, e error) {
return
}
// Put bytes to file.
// FilePutContents Put bytes to file.
// if non-exist, create this file.
func File_put_contents(filename string, content []byte) error {
func FilePutContents(filename string, content []byte) error {
fp, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm)
if err != nil {
return err
@ -244,8 +251,8 @@ func File_put_contents(filename string, content []byte) error {
return err
}
// Gob encodes file cache item.
func Gob_encode(data interface{}) ([]byte, error) {
// GobEncode Gob encodes file cache item.
func GobEncode(data interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(data)
@ -255,9 +262,13 @@ func Gob_encode(data interface{}) ([]byte, error) {
return buf.Bytes(), err
}
// Gob decodes file cache item.
func Gob_decode(data []byte, to *FileCacheItem) error {
// GobDecode Gob decodes file cache item.
func GobDecode(data []byte, to *FileCacheItem) error {
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
return dec.Decode(&to)
}
func init() {
Register("file", NewFileCache)
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// package memcahe for cache provider
// Package memcache for cache provider
//
// depend on github.com/bradfitz/gomemcache/memcache
//
@ -36,22 +36,24 @@ import (
"github.com/bradfitz/gomemcache/memcache"
"time"
"github.com/astaxie/beego/cache"
)
// Memcache adapter.
type MemcacheCache struct {
// Cache Memcache adapter.
type Cache struct {
conn *memcache.Client
conninfo []string
}
// create new memcache adapter.
func NewMemCache() *MemcacheCache {
return &MemcacheCache{}
// NewMemCache create new memcache adapter.
func NewMemCache() cache.Cache {
return &Cache{}
}
// get value from memcache.
func (rc *MemcacheCache) Get(key string) interface{} {
// Get get value from memcache.
func (rc *Cache) Get(key string) interface{} {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@ -63,8 +65,33 @@ func (rc *MemcacheCache) Get(key string) interface{} {
return nil
}
// put value to memcache. only support string.
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
// GetMulti get value from memcache.
func (rc *Cache) GetMulti(keys []string) []interface{} {
size := len(keys)
var rv []interface{}
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
for i := 0; i < size; i++ {
rv = append(rv, err)
}
return rv
}
}
mv, err := rc.conn.GetMulti(keys)
if err == nil {
for _, v := range mv {
rv = append(rv, string(v.Value))
}
return rv
}
for i := 0; i < size; i++ {
rv = append(rv, err)
}
return rv
}
// Put put value to memcache. only support string.
func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@ -74,12 +101,12 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
if !ok {
return errors.New("val must string")
}
item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout)}
item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout / time.Second)}
return rc.conn.Set(&item)
}
// delete value in memcache.
func (rc *MemcacheCache) Delete(key string) error {
// Delete delete value in memcache.
func (rc *Cache) Delete(key string) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@ -88,8 +115,8 @@ func (rc *MemcacheCache) Delete(key string) error {
return rc.conn.Delete(key)
}
// increase counter.
func (rc *MemcacheCache) Incr(key string) error {
// Incr increase counter.
func (rc *Cache) Incr(key string) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@ -99,8 +126,8 @@ func (rc *MemcacheCache) Incr(key string) error {
return err
}
// decrease counter.
func (rc *MemcacheCache) Decr(key string) error {
// Decr decrease counter.
func (rc *Cache) Decr(key string) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@ -110,8 +137,8 @@ func (rc *MemcacheCache) Decr(key string) error {
return err
}
// check value exists in memcache.
func (rc *MemcacheCache) IsExist(key string) bool {
// IsExist check value exists in memcache.
func (rc *Cache) IsExist(key string) bool {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return false
@ -124,8 +151,8 @@ func (rc *MemcacheCache) IsExist(key string) bool {
return true
}
// clear all cached in memcache.
func (rc *MemcacheCache) ClearAll() error {
// ClearAll clear all cached in memcache.
func (rc *Cache) ClearAll() error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
@ -134,10 +161,10 @@ func (rc *MemcacheCache) ClearAll() error {
return rc.conn.FlushAll()
}
// start memcache adapter.
// StartAndGC start memcache adapter.
// config string is like {"conn":"connection info"}.
// if connecting error, return.
func (rc *MemcacheCache) StartAndGC(config string) error {
func (rc *Cache) StartAndGC(config string) error {
var cf map[string]string
json.Unmarshal([]byte(config), &cf)
if _, ok := cf["conn"]; !ok {
@ -153,11 +180,11 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
}
// connect to memcache and keep the connection.
func (rc *MemcacheCache) connectInit() error {
func (rc *Cache) connectInit() error {
rc.conn = memcache.New(rc.conninfo...)
return nil
}
func init() {
cache.Register("memcache", NewMemCache())
cache.Register("memcache", NewMemCache)
}

108
cache/memcache/memcache_test.go vendored Normal file
View File

@ -0,0 +1,108 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package memcache
import (
_ "github.com/bradfitz/gomemcache/memcache"
"strconv"
"testing"
"time"
"github.com/astaxie/beego/cache"
)
func TestMemcacheCache(t *testing.T) {
bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`)
if err != nil {
t.Error("init err")
}
timeoutDuration := 10 * time.Second
if err = bm.Put("astaxie", "1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
time.Sleep(10 * time.Second)
if bm.IsExist("astaxie") {
t.Error("check err")
}
if err = bm.Put("astaxie", "1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if v, err := strconv.Atoi(bm.Get("astaxie").(string)); err != nil || v != 1 {
t.Error("get err")
}
if err = bm.Incr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v, err := strconv.Atoi(bm.Get("astaxie").(string)); err != nil || v != 2 {
t.Error("get err")
}
if err = bm.Decr("astaxie"); err != nil {
t.Error("Decr Error", err)
}
if v, err := strconv.Atoi(bm.Get("astaxie").(string)); err != nil || v != 1 {
t.Error("get err")
}
bm.Delete("astaxie")
if bm.IsExist("astaxie") {
t.Error("delete err")
}
//test string
if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie").(string); v != "author" {
t.Error("get err")
}
//test GetMulti
if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie1") {
t.Error("check err")
}
vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
if len(vv) != 2 {
t.Error("GetMulti ERROR")
}
if vv[0].(string) != "author" && vv[0].(string) != "author1" {
t.Error("GetMulti ERROR")
}
if vv[1].(string) != "author1" && vv[1].(string) != "author" {
t.Error("GetMulti ERROR")
}
// test clear all
if err = bm.ClearAll(); err != nil {
t.Error("clear all err")
}
}

122
cache/memory.go vendored
View File

@ -17,34 +17,41 @@ package cache
import (
"encoding/json"
"errors"
"fmt"
"sync"
"time"
)
var (
// clock time of recycling the expired cache items in memory.
DefaultEvery int = 60 // 1 minute
// DefaultEvery means the clock time of recycling the expired cache items in memory.
DefaultEvery = 60 // 1 minute
)
// Memory cache item.
// MemoryItem store memory cache item.
type MemoryItem struct {
val interface{}
Lastaccess time.Time
expired int64
createdTime time.Time
lifespan time.Duration
}
// Memory cache adapter.
func (mi *MemoryItem) isExpire() bool {
// 0 means forever
if mi.lifespan == 0 {
return false
}
return time.Now().Sub(mi.createdTime) > mi.lifespan
}
// MemoryCache is Memory cache adapter.
// it contains a RW locker for safe map storage.
type MemoryCache struct {
lock sync.RWMutex
sync.RWMutex
dur time.Duration
items map[string]*MemoryItem
Every int // run an expiration check Every clock time
}
// NewMemoryCache returns a new MemoryCache.
func NewMemoryCache() *MemoryCache {
func NewMemoryCache() Cache {
cache := MemoryCache{items: make(map[string]*MemoryItem)}
return &cache
}
@ -52,11 +59,10 @@ func NewMemoryCache() *MemoryCache {
// Get cache from memory.
// if non-existed or expired, return nil.
func (bc *MemoryCache) Get(name string) interface{} {
bc.lock.RLock()
defer bc.lock.RUnlock()
bc.RLock()
defer bc.RUnlock()
if itm, ok := bc.items[name]; ok {
if (time.Now().Unix() - itm.Lastaccess.Unix()) > itm.expired {
go bc.Delete(name)
if itm.isExpire() {
return nil
}
return itm.val
@ -64,23 +70,33 @@ func (bc *MemoryCache) Get(name string) interface{} {
return nil
}
// GetMulti gets caches from memory.
// if non-existed or expired, return nil.
func (bc *MemoryCache) GetMulti(names []string) []interface{} {
var rc []interface{}
for _, name := range names {
rc = append(rc, bc.Get(name))
}
return rc
}
// Put cache to memory.
// if expired is 0, it will be cleaned by next gc operation ( default gc clock is 1 minute).
func (bc *MemoryCache) Put(name string, value interface{}, expired int64) error {
bc.lock.Lock()
defer bc.lock.Unlock()
// if lifespan is 0, it will be forever till restart.
func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error {
bc.Lock()
defer bc.Unlock()
bc.items[name] = &MemoryItem{
val: value,
Lastaccess: time.Now(),
expired: expired,
createdTime: time.Now(),
lifespan: lifespan,
}
return nil
}
/// Delete cache in memory.
// Delete cache in memory.
func (bc *MemoryCache) Delete(name string) error {
bc.lock.Lock()
defer bc.lock.Unlock()
bc.Lock()
defer bc.Unlock()
if _, ok := bc.items[name]; !ok {
return errors.New("key not exist")
}
@ -91,11 +107,11 @@ func (bc *MemoryCache) Delete(name string) error {
return nil
}
// Increase cache counter in memory.
// it supports int,int64,int32,uint,uint64,uint32.
// Incr increase cache counter in memory.
// it supports int,int32,int64,uint,uint32,uint64.
func (bc *MemoryCache) Incr(key string) error {
bc.lock.RLock()
defer bc.lock.RUnlock()
bc.RLock()
defer bc.RUnlock()
itm, ok := bc.items[key]
if !ok {
return errors.New("key not exist")
@ -103,10 +119,10 @@ func (bc *MemoryCache) Incr(key string) error {
switch itm.val.(type) {
case int:
itm.val = itm.val.(int) + 1
case int64:
itm.val = itm.val.(int64) + 1
case int32:
itm.val = itm.val.(int32) + 1
case int64:
itm.val = itm.val.(int64) + 1
case uint:
itm.val = itm.val.(uint) + 1
case uint32:
@ -114,15 +130,15 @@ func (bc *MemoryCache) Incr(key string) error {
case uint64:
itm.val = itm.val.(uint64) + 1
default:
return errors.New("item val is not int int64 int32")
return errors.New("item val is not (u)int (u)int32 (u)int64")
}
return nil
}
// Decrease counter in memory.
// Decr decrease counter in memory.
func (bc *MemoryCache) Decr(key string) error {
bc.lock.RLock()
defer bc.lock.RUnlock()
bc.RLock()
defer bc.RUnlock()
itm, ok := bc.items[key]
if !ok {
return errors.New("key not exist")
@ -158,23 +174,25 @@ func (bc *MemoryCache) Decr(key string) error {
return nil
}
// check cache exist in memory.
// IsExist check cache exist in memory.
func (bc *MemoryCache) IsExist(name string) bool {
bc.lock.RLock()
defer bc.lock.RUnlock()
_, ok := bc.items[name]
return ok
bc.RLock()
defer bc.RUnlock()
if v, ok := bc.items[name]; ok {
return !v.isExpire()
}
return false
}
// delete all cache in memory.
// ClearAll will delete all cache in memory.
func (bc *MemoryCache) ClearAll() error {
bc.lock.Lock()
defer bc.lock.Unlock()
bc.Lock()
defer bc.Unlock()
bc.items = make(map[string]*MemoryItem)
return nil
}
// start memory cache. it will check expiration in every clock time.
// StartAndGC start memory cache. it will check expiration in every clock time.
func (bc *MemoryCache) StartAndGC(config string) error {
var cf map[string]int
json.Unmarshal([]byte(config), &cf)
@ -182,10 +200,7 @@ func (bc *MemoryCache) StartAndGC(config string) error {
cf = make(map[string]int)
cf["interval"] = DefaultEvery
}
dur, err := time.ParseDuration(fmt.Sprintf("%ds", cf["interval"]))
if err != nil {
return err
}
dur := time.Duration(cf["interval"]) * time.Second
bc.Every = cf["interval"]
bc.dur = dur
go bc.vaccuum()
@ -202,21 +217,22 @@ func (bc *MemoryCache) vaccuum() {
if bc.items == nil {
return
}
for name, _ := range bc.items {
bc.item_expired(name)
for name := range bc.items {
bc.itemExpired(name)
}
}
}
// item_expired returns true if an item is expired.
func (bc *MemoryCache) item_expired(name string) bool {
bc.lock.Lock()
defer bc.lock.Unlock()
// itemExpired returns true if an item is expired.
func (bc *MemoryCache) itemExpired(name string) bool {
bc.Lock()
defer bc.Unlock()
itm, ok := bc.items[name]
if !ok {
return true
}
if time.Now().Unix()-itm.Lastaccess.Unix() >= itm.expired {
if itm.isExpire() {
delete(bc.items, name)
return true
}
@ -224,5 +240,5 @@ func (bc *MemoryCache) item_expired(name string) bool {
}
func init() {
Register("memory", NewMemoryCache())
Register("memory", NewMemoryCache)
}

105
cache/redis/redis.go vendored
View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// package redis for cache provider
// Package redis for cache provider
//
// depend on github.com/garyburd/redigo/redis
//
@ -41,25 +41,26 @@ import (
)
var (
// the collection name of redis for cache adapter.
DefaultKey string = "beecacheRedis"
// DefaultKey the collection name of redis for cache adapter.
DefaultKey = "beecacheRedis"
)
// Redis cache adapter.
type RedisCache struct {
// Cache is Redis cache adapter.
type Cache struct {
p *redis.Pool // redis connection pool
conninfo string
dbNum int
key string
password string
}
// create new redis cache with default collection name.
func NewRedisCache() *RedisCache {
return &RedisCache{key: DefaultKey}
// NewRedisCache create new redis cache with default collection name.
func NewRedisCache() cache.Cache {
return &Cache{key: DefaultKey}
}
// actually do the redis cmds
func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
c := rc.p.Get()
defer c.Close()
@ -67,17 +68,50 @@ func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interfa
}
// Get cache from redis.
func (rc *RedisCache) Get(key string) interface{} {
func (rc *Cache) Get(key string) interface{} {
if v, err := rc.do("GET", key); err == nil {
return v
}
return nil
}
// put cache to redis.
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
// GetMulti get cache from redis.
func (rc *Cache) GetMulti(keys []string) []interface{} {
size := len(keys)
var rv []interface{}
c := rc.p.Get()
defer c.Close()
var err error
if _, err = rc.do("SETEX", key, timeout, val); err != nil {
for _, key := range keys {
err = c.Send("GET", key)
if err != nil {
goto ERROR
}
}
if err = c.Flush(); err != nil {
goto ERROR
}
for i := 0; i < size; i++ {
if v, err := c.Receive(); err == nil {
rv = append(rv, v.([]byte))
} else {
rv = append(rv, err)
}
}
return rv
ERROR:
rv = rv[0:0]
for i := 0; i < size; i++ {
rv = append(rv, nil)
}
return rv
}
// Put put cache to redis.
func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
var err error
if _, err = rc.do("SETEX", key, int64(timeout/time.Second), val); err != nil {
return err
}
@ -87,8 +121,8 @@ func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
return err
}
// delete cache in redis.
func (rc *RedisCache) Delete(key string) error {
// Delete delete cache in redis.
func (rc *Cache) Delete(key string) error {
var err error
if _, err = rc.do("DEL", key); err != nil {
return err
@ -97,8 +131,8 @@ func (rc *RedisCache) Delete(key string) error {
return err
}
// check cache's existence in redis.
func (rc *RedisCache) IsExist(key string) bool {
// IsExist check cache's existence in redis.
func (rc *Cache) IsExist(key string) bool {
v, err := redis.Bool(rc.do("EXISTS", key))
if err != nil {
return false
@ -111,20 +145,20 @@ func (rc *RedisCache) IsExist(key string) bool {
return v
}
// increase counter in redis.
func (rc *RedisCache) Incr(key string) error {
// Incr increase counter in redis.
func (rc *Cache) Incr(key string) error {
_, err := redis.Bool(rc.do("INCRBY", key, 1))
return err
}
// decrease counter in redis.
func (rc *RedisCache) Decr(key string) error {
// Decr decrease counter in redis.
func (rc *Cache) Decr(key string) error {
_, err := redis.Bool(rc.do("INCRBY", key, -1))
return err
}
// clean all cache in redis. delete this redis collection.
func (rc *RedisCache) ClearAll() error {
// ClearAll clean all cache in redis. delete this redis collection.
func (rc *Cache) ClearAll() error {
cachedKeys, err := redis.Strings(rc.do("HKEYS", rc.key))
if err != nil {
return err
@ -138,27 +172,31 @@ func (rc *RedisCache) ClearAll() error {
return err
}
// start redis cache adapter.
// StartAndGC start redis cache adapter.
// config is like {"key":"collection key","conn":"connection info","dbNum":"0"}
// the cache item in redis are stored forever,
// so no gc operation.
func (rc *RedisCache) StartAndGC(config string) error {
func (rc *Cache) StartAndGC(config string) error {
var cf map[string]string
json.Unmarshal([]byte(config), &cf)
if _, ok := cf["key"]; !ok {
cf["key"] = DefaultKey
}
if _, ok := cf["conn"]; !ok {
return errors.New("config has no conn key")
}
if _, ok := cf["dbNum"]; !ok {
cf["dbNum"] = "0"
}
if _, ok := cf["password"]; !ok {
cf["password"] = ""
}
rc.key = cf["key"]
rc.conninfo = cf["conn"]
rc.dbNum, _ = strconv.Atoi(cf["dbNum"])
rc.password = cf["password"]
rc.connectInit()
c := rc.p.Get()
@ -168,9 +206,20 @@ func (rc *RedisCache) StartAndGC(config string) error {
}
// connect to redis.
func (rc *RedisCache) connectInit() {
func (rc *Cache) connectInit() {
dialFunc := func() (c redis.Conn, err error) {
c, err = redis.Dial("tcp", rc.conninfo)
if err != nil {
return nil, err
}
if rc.password != "" {
if _, err := c.Do("AUTH", rc.password); err != nil {
c.Close()
return nil, err
}
}
_, selecterr := c.Do("SELECT", rc.dbNum)
if selecterr != nil {
c.Close()
@ -187,5 +236,5 @@ func (rc *RedisCache) connectInit() {
}
func init() {
cache.Register("redis", NewRedisCache())
cache.Register("redis", NewRedisCache)
}

View File

@ -28,19 +28,20 @@ func TestRedisCache(t *testing.T) {
if err != nil {
t.Error("init err")
}
if err = bm.Put("astaxie", 1, 10); err != nil {
timeoutDuration := 10 * time.Second
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
time.Sleep(10 * time.Second)
time.Sleep(11 * time.Second)
if bm.IsExist("astaxie") {
t.Error("check err")
}
if err = bm.Put("astaxie", 1, 10); err != nil {
if err = bm.Put("astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err)
}
@ -67,8 +68,9 @@ func TestRedisCache(t *testing.T) {
if bm.IsExist("astaxie") {
t.Error("delete err")
}
//test string
if err = bm.Put("astaxie", "author", 10); err != nil {
if err = bm.Put("astaxie", "author", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
@ -78,6 +80,26 @@ func TestRedisCache(t *testing.T) {
if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" {
t.Error("get err")
}
//test GetMulti
if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie1") {
t.Error("check err")
}
vv := bm.GetMulti([]string{"astaxie", "astaxie1"})
if len(vv) != 2 {
t.Error("GetMulti ERROR")
}
if v, _ := redis.String(vv[0], nil); v != "author" {
t.Error("GetMulti ERROR")
}
if v, _ := redis.String(vv[1], nil); v != "author1" {
t.Error("GetMulti ERROR")
}
// test clear all
if err = bm.ClearAll(); err != nil {
t.Error("clear all err")

679
config.go
View File

@ -15,77 +15,269 @@
package beego
import (
"fmt"
"html/template"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/astaxie/beego/config"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
)
var (
BeeApp *App // beego application
AppName string
AppPath string
workPath string
AppConfigPath string
StaticDir map[string]string
TemplateCache map[string]*template.Template // template caching map
StaticExtensionsToGzip []string // files with should be compressed with gzip (.js,.css,etc)
EnableHttpListen bool
HttpAddr string
HttpPort int
ListenTCP4 bool
EnableHttpTLS bool
HttpsPort int
HttpCertFile string
HttpKeyFile string
RecoverPanic bool // flag of auto recover panic
AutoRender bool // flag of render template automatically
ViewsPath string
AppConfig *beegoAppConfig
RunMode string // run mode, "dev" or "prod"
GlobalSessions *session.Manager // global session mananger
SessionOn bool // flag of starting session auto. default is false.
SessionProvider string // default session provider, memory, mysql , redis ,etc.
SessionName string // the cookie name when saving session id into cookie.
SessionGCMaxLifetime int64 // session gc time for auto cleaning expired session.
SessionSavePath string // if use mysql/redis/file provider, define save path to connection info.
SessionCookieLifeTime int // the life time of session id in cookie.
SessionAutoSetCookie bool // auto setcookie
SessionDomain string // the cookie domain default is empty
UseFcgi bool
UseStdIo bool
// Config is the main struct for BConfig
type Config struct {
AppName string //Application name
RunMode string //Running Mode: dev | prod
RouterCaseSensitive bool
ServerName string
RecoverPanic bool
CopyRequestBody bool
EnableGzip bool
MaxMemory int64
EnableGzip bool // flag of enable gzip
DirectoryIndex bool // flag of display directory index. default is false.
HttpServerTimeOut int64
ErrorsShow bool // flag of show errors in page. if true, show error and trace info in page rendered with error template.
XSRFKEY string // xsrf hash salt string.
EnableXSRF bool // flag of enable xsrf.
XSRFExpire int // the expiry of xsrf value.
CopyRequestBody bool // flag of copy raw request body in context.
EnableErrorsShow bool
Listen Listen
WebConfig WebConfig
Log LogConfig
}
// Listen holds for http and https related config
type Listen struct {
Graceful bool // Graceful means use graceful module to start the server
ServerTimeOut int64
ListenTCP4 bool
EnableHTTP bool
HTTPAddr string
HTTPPort int
EnableHTTPS bool
HTTPSAddr string
HTTPSPort int
HTTPSCertFile string
HTTPSKeyFile string
EnableAdmin bool
AdminAddr string
AdminPort int
EnableFcgi bool
EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O
}
// WebConfig holds web related config
type WebConfig struct {
AutoRender bool
EnableDocs bool
FlashName string
FlashSeparator string
DirectoryIndex bool
StaticDir map[string]string
StaticExtensionsToGzip []string
TemplateLeft string
TemplateRight string
BeegoServerName string // beego server name exported in response header.
EnableAdmin bool // flag of enable admin module to log every request info.
AdminHttpAddr string // http server configurations for admin module.
AdminHttpPort int
FlashName string // name of the flash variable found in response header and cookie
FlashSeperator string // used to seperate flash key:value
AppConfigProvider string // config provider
EnableDocs bool // enable generate docs & server docs API Swagger
RouterCaseSensitive bool // router case sensitive default is true
AccessLogs bool // print access logs, default is false
ViewsPath string
EnableXSRF bool
XSRFKey string
XSRFExpire int
Session SessionConfig
}
// SessionConfig holds session related config
type SessionConfig struct {
SessionOn bool
SessionProvider string
SessionName string
SessionGCMaxLifetime int64
SessionProviderConfig string
SessionCookieLifeTime int
SessionAutoSetCookie bool
SessionDomain string
}
// LogConfig holds Log related config
type LogConfig struct {
AccessLogs bool
FileLineNum bool
Outputs map[string]string // Store Adaptor : config
}
var (
// BConfig is the default config for Application
BConfig *Config
// AppConfig is the instance of Config, store the config information from file
AppConfig *beegoAppConfig
// AppConfigPath is the path to the config files
AppConfigPath string
// AppConfigProvider is the provider for the config, default is ini
AppConfigProvider = "ini"
// TemplateCache stores template caching
TemplateCache map[string]*template.Template
// GlobalSessions is the instance for the session manager
GlobalSessions *session.Manager
)
func init() {
BConfig = &Config{
AppName: "beego",
RunMode: DEV,
RouterCaseSensitive: true,
ServerName: "beegoServer:" + VERSION,
RecoverPanic: true,
CopyRequestBody: false,
EnableGzip: false,
MaxMemory: 1 << 26, //64MB
EnableErrorsShow: true,
Listen: Listen{
Graceful: false,
ServerTimeOut: 0,
ListenTCP4: false,
EnableHTTP: true,
HTTPAddr: "",
HTTPPort: 8080,
EnableHTTPS: false,
HTTPSAddr: "",
HTTPSPort: 10443,
HTTPSCertFile: "",
HTTPSKeyFile: "",
EnableAdmin: false,
AdminAddr: "",
AdminPort: 8088,
EnableFcgi: false,
EnableStdIo: false,
},
WebConfig: WebConfig{
AutoRender: true,
EnableDocs: false,
FlashName: "BEEGO_FLASH",
FlashSeparator: "BEEGOFLASH",
DirectoryIndex: false,
StaticDir: map[string]string{"/static": "static"},
StaticExtensionsToGzip: []string{".css", ".js"},
TemplateLeft: "{{",
TemplateRight: "}}",
ViewsPath: "views",
EnableXSRF: false,
XSRFKey: "beegoxsrf",
XSRFExpire: 0,
Session: SessionConfig{
SessionOn: false,
SessionProvider: "memory",
SessionName: "beegosessionID",
SessionGCMaxLifetime: 3600,
SessionProviderConfig: "",
SessionCookieLifeTime: 0, //set cookie default is the brower life
SessionAutoSetCookie: true,
SessionDomain: "",
},
},
Log: LogConfig{
AccessLogs: false,
FileLineNum: true,
Outputs: map[string]string{"console": ""},
},
}
ParseConfig()
}
// ParseConfig parsed default config file.
// now only support ini, next will support json.
func ParseConfig() (err error) {
if AppConfigPath == "" {
if utils.FileExists(filepath.Join("conf", "app.conf")) {
AppConfigPath = filepath.Join("conf", "app.conf")
} else {
AppConfig = &beegoAppConfig{config.NewFakeConfig()}
return
}
}
AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath)
if err != nil {
return err
}
// set the runmode first
if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
BConfig.RunMode = envRunMode
} else if runmode := AppConfig.String("RunMode"); runmode != "" {
BConfig.RunMode = runmode
}
BConfig.AppName = AppConfig.DefaultString("AppName", BConfig.AppName)
BConfig.RecoverPanic = AppConfig.DefaultBool("RecoverPanic", BConfig.RecoverPanic)
BConfig.RouterCaseSensitive = AppConfig.DefaultBool("RouterCaseSensitive", BConfig.RouterCaseSensitive)
BConfig.ServerName = AppConfig.DefaultString("ServerName", BConfig.ServerName)
BConfig.EnableGzip = AppConfig.DefaultBool("EnableGzip", BConfig.EnableGzip)
BConfig.EnableErrorsShow = AppConfig.DefaultBool("EnableErrorsShow", BConfig.EnableErrorsShow)
BConfig.CopyRequestBody = AppConfig.DefaultBool("CopyRequestBody", BConfig.CopyRequestBody)
BConfig.MaxMemory = AppConfig.DefaultInt64("MaxMemory", BConfig.MaxMemory)
BConfig.Listen.Graceful = AppConfig.DefaultBool("Graceful", BConfig.Listen.Graceful)
BConfig.Listen.HTTPAddr = AppConfig.String("HTTPAddr")
BConfig.Listen.HTTPPort = AppConfig.DefaultInt("HTTPPort", BConfig.Listen.HTTPPort)
BConfig.Listen.ListenTCP4 = AppConfig.DefaultBool("ListenTCP4", BConfig.Listen.ListenTCP4)
BConfig.Listen.EnableHTTP = AppConfig.DefaultBool("EnableHTTP", BConfig.Listen.EnableHTTP)
BConfig.Listen.EnableHTTPS = AppConfig.DefaultBool("EnableHTTPS", BConfig.Listen.EnableHTTPS)
BConfig.Listen.HTTPSAddr = AppConfig.DefaultString("HTTPSAddr", BConfig.Listen.HTTPSAddr)
BConfig.Listen.HTTPSPort = AppConfig.DefaultInt("HTTPSPort", BConfig.Listen.HTTPSPort)
BConfig.Listen.HTTPSCertFile = AppConfig.DefaultString("HTTPSCertFile", BConfig.Listen.HTTPSCertFile)
BConfig.Listen.HTTPSKeyFile = AppConfig.DefaultString("HTTPSKeyFile", BConfig.Listen.HTTPSKeyFile)
BConfig.Listen.EnableAdmin = AppConfig.DefaultBool("EnableAdmin", BConfig.Listen.EnableAdmin)
BConfig.Listen.AdminAddr = AppConfig.DefaultString("AdminAddr", BConfig.Listen.AdminAddr)
BConfig.Listen.AdminPort = AppConfig.DefaultInt("AdminPort", BConfig.Listen.AdminPort)
BConfig.Listen.EnableFcgi = AppConfig.DefaultBool("EnableFcgi", BConfig.Listen.EnableFcgi)
BConfig.Listen.EnableStdIo = AppConfig.DefaultBool("EnableStdIo", BConfig.Listen.EnableStdIo)
BConfig.Listen.ServerTimeOut = AppConfig.DefaultInt64("ServerTimeOut", BConfig.Listen.ServerTimeOut)
BConfig.WebConfig.AutoRender = AppConfig.DefaultBool("AutoRender", BConfig.WebConfig.AutoRender)
BConfig.WebConfig.ViewsPath = AppConfig.DefaultString("ViewsPath", BConfig.WebConfig.ViewsPath)
BConfig.WebConfig.DirectoryIndex = AppConfig.DefaultBool("DirectoryIndex", BConfig.WebConfig.DirectoryIndex)
BConfig.WebConfig.FlashName = AppConfig.DefaultString("FlashName", BConfig.WebConfig.FlashName)
BConfig.WebConfig.FlashSeparator = AppConfig.DefaultString("FlashSeparator", BConfig.WebConfig.FlashSeparator)
BConfig.WebConfig.EnableDocs = AppConfig.DefaultBool("EnableDocs", BConfig.WebConfig.EnableDocs)
BConfig.WebConfig.XSRFKey = AppConfig.DefaultString("XSRFKEY", BConfig.WebConfig.XSRFKey)
BConfig.WebConfig.EnableXSRF = AppConfig.DefaultBool("EnableXSRF", BConfig.WebConfig.EnableXSRF)
BConfig.WebConfig.XSRFExpire = AppConfig.DefaultInt("XSRFExpire", BConfig.WebConfig.XSRFExpire)
BConfig.WebConfig.TemplateLeft = AppConfig.DefaultString("TemplateLeft", BConfig.WebConfig.TemplateLeft)
BConfig.WebConfig.TemplateRight = AppConfig.DefaultString("TemplateRight", BConfig.WebConfig.TemplateRight)
BConfig.WebConfig.Session.SessionOn = AppConfig.DefaultBool("SessionOn", BConfig.WebConfig.Session.SessionOn)
BConfig.WebConfig.Session.SessionProvider = AppConfig.DefaultString("SessionProvider", BConfig.WebConfig.Session.SessionProvider)
BConfig.WebConfig.Session.SessionName = AppConfig.DefaultString("SessionName", BConfig.WebConfig.Session.SessionName)
BConfig.WebConfig.Session.SessionProviderConfig = AppConfig.DefaultString("SessionProviderConfig", BConfig.WebConfig.Session.SessionProviderConfig)
BConfig.WebConfig.Session.SessionGCMaxLifetime = AppConfig.DefaultInt64("SessionGCMaxLifetime", BConfig.WebConfig.Session.SessionGCMaxLifetime)
BConfig.WebConfig.Session.SessionCookieLifeTime = AppConfig.DefaultInt("SessionCookieLifeTime", BConfig.WebConfig.Session.SessionCookieLifeTime)
BConfig.WebConfig.Session.SessionAutoSetCookie = AppConfig.DefaultBool("SessionAutoSetCookie", BConfig.WebConfig.Session.SessionAutoSetCookie)
BConfig.WebConfig.Session.SessionDomain = AppConfig.DefaultString("SessionDomain", BConfig.WebConfig.Session.SessionDomain)
if sd := AppConfig.String("StaticDir"); sd != "" {
for k := range BConfig.WebConfig.StaticDir {
delete(BConfig.WebConfig.StaticDir, k)
}
sds := strings.Fields(sd)
for _, v := range sds {
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
} else {
BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
}
}
}
if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" {
extensions := strings.Split(sgz, ",")
fileExts := []string{}
for _, ext := range extensions {
ext = strings.TrimSpace(ext)
if ext == "" {
continue
}
if !strings.HasPrefix(ext, ".") {
ext = "." + ext
}
fileExts = append(fileExts, ext)
}
if len(fileExts) > 0 {
BConfig.WebConfig.StaticExtensionsToGzip = fileExts
}
}
return nil
}
type beegoAppConfig struct {
innerConfig config.ConfigContainer
innerConfig config.Configer
}
func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, error) {
@ -93,84 +285,98 @@ func newAppConfig(AppConfigProvider, AppConfigPath string) (*beegoAppConfig, err
if err != nil {
return nil, err
}
rac := &beegoAppConfig{ac}
return rac, nil
return &beegoAppConfig{ac}, nil
}
func (b *beegoAppConfig) Set(key, val string) error {
if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil {
return err
}
return b.innerConfig.Set(key, val)
}
func (b *beegoAppConfig) String(key string) string {
v := b.innerConfig.String(RunMode + "::" + key)
if v == "" {
return b.innerConfig.String(key)
}
if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" {
return v
}
return b.innerConfig.String(key)
}
func (b *beegoAppConfig) Strings(key string) []string {
v := b.innerConfig.Strings(RunMode + "::" + key)
if v[0] == "" {
return b.innerConfig.Strings(key)
}
if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); v[0] != "" {
return v
}
return b.innerConfig.Strings(key)
}
func (b *beegoAppConfig) Int(key string) (int, error) {
v, err := b.innerConfig.Int(RunMode + "::" + key)
if err != nil {
return b.innerConfig.Int(key)
}
if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil {
return v, nil
}
return b.innerConfig.Int(key)
}
func (b *beegoAppConfig) Int64(key string) (int64, error) {
v, err := b.innerConfig.Int64(RunMode + "::" + key)
if err != nil {
return b.innerConfig.Int64(key)
}
if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil {
return v, nil
}
return b.innerConfig.Int64(key)
}
func (b *beegoAppConfig) Bool(key string) (bool, error) {
v, err := b.innerConfig.Bool(RunMode + "::" + key)
if err != nil {
return b.innerConfig.Bool(key)
}
if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil {
return v, nil
}
return b.innerConfig.Bool(key)
}
func (b *beegoAppConfig) Float(key string) (float64, error) {
v, err := b.innerConfig.Float(RunMode + "::" + key)
if err != nil {
return b.innerConfig.Float(key)
}
if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil {
return v, nil
}
return b.innerConfig.Float(key)
}
func (b *beegoAppConfig) DefaultString(key string, defaultval string) string {
return b.innerConfig.DefaultString(key, defaultval)
if v := b.String(key); v != "" {
return v
}
return defaultval
}
func (b *beegoAppConfig) DefaultStrings(key string, defaultval []string) []string {
return b.innerConfig.DefaultStrings(key, defaultval)
if v := b.Strings(key); len(v) != 0 {
return v
}
return defaultval
}
func (b *beegoAppConfig) DefaultInt(key string, defaultval int) int {
return b.innerConfig.DefaultInt(key, defaultval)
if v, err := b.Int(key); err == nil {
return v
}
return defaultval
}
func (b *beegoAppConfig) DefaultInt64(key string, defaultval int64) int64 {
return b.innerConfig.DefaultInt64(key, defaultval)
if v, err := b.Int64(key); err == nil {
return v
}
return defaultval
}
func (b *beegoAppConfig) DefaultBool(key string, defaultval bool) bool {
return b.innerConfig.DefaultBool(key, defaultval)
if v, err := b.Bool(key); err == nil {
return v
}
return defaultval
}
func (b *beegoAppConfig) DefaultFloat(key string, defaultval float64) float64 {
return b.innerConfig.DefaultFloat(key, defaultval)
if v, err := b.Float(key); err == nil {
return v
}
return defaultval
}
func (b *beegoAppConfig) DIY(key string) (interface{}, error) {
@ -184,302 +390,3 @@ func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) {
func (b *beegoAppConfig) SaveConfigFile(filename string) error {
return b.innerConfig.SaveConfigFile(filename)
}
func init() {
// create beego application
BeeApp = NewApp()
workPath, _ = os.Getwd()
workPath, _ = filepath.Abs(workPath)
// initialize default configurations
AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
if workPath != AppPath {
if utils.FileExists(AppConfigPath) {
os.Chdir(AppPath)
} else {
AppConfigPath = filepath.Join(workPath, "conf", "app.conf")
}
}
AppConfigProvider = "ini"
StaticDir = make(map[string]string)
StaticDir["/static"] = "static"
StaticExtensionsToGzip = []string{".css", ".js"}
TemplateCache = make(map[string]*template.Template)
// set this to 0.0.0.0 to make this app available to externally
EnableHttpListen = true //default enable http Listen
HttpAddr = ""
HttpPort = 8080
HttpsPort = 10443
AppName = "beego"
RunMode = "dev" //default runmod
AutoRender = true
RecoverPanic = true
ViewsPath = "views"
SessionOn = false
SessionProvider = "memory"
SessionName = "beegosessionID"
SessionGCMaxLifetime = 3600
SessionSavePath = ""
SessionCookieLifeTime = 0 //set cookie default is the brower life
SessionAutoSetCookie = true
UseFcgi = false
UseStdIo = false
MaxMemory = 1 << 26 //64MB
EnableGzip = false
HttpServerTimeOut = 0
ErrorsShow = true
XSRFKEY = "beegoxsrf"
XSRFExpire = 0
TemplateLeft = "{{"
TemplateRight = "}}"
BeegoServerName = "beegoServer:" + VERSION
EnableAdmin = false
AdminHttpAddr = "127.0.0.1"
AdminHttpPort = 8088
FlashName = "BEEGO_FLASH"
FlashSeperator = "BEEGOFLASH"
RouterCaseSensitive = true
runtime.GOMAXPROCS(runtime.NumCPU())
// init BeeLogger
BeeLogger = logs.NewLogger(10000)
err := BeeLogger.SetLogger("console", "")
if err != nil {
fmt.Println("init console log error:", err)
}
SetLogFuncCall(true)
err = ParseConfig()
if err != nil && os.IsNotExist(err) {
// for init if doesn't have app.conf will not panic
ac := config.NewFakeConfig()
AppConfig = &beegoAppConfig{ac}
Warning(err)
}
}
// ParseConfig parsed default config file.
// now only support ini, next will support json.
func ParseConfig() (err error) {
AppConfig, err = newAppConfig(AppConfigProvider, AppConfigPath)
if err != nil {
return err
}
envRunMode := os.Getenv("BEEGO_RUNMODE")
// set the runmode first
if envRunMode != "" {
RunMode = envRunMode
} else if runmode := AppConfig.String("RunMode"); runmode != "" {
RunMode = runmode
}
HttpAddr = AppConfig.String("HttpAddr")
if v, err := AppConfig.Int("HttpPort"); err == nil {
HttpPort = v
}
if v, err := AppConfig.Bool("ListenTCP4"); err == nil {
ListenTCP4 = v
}
if v, err := AppConfig.Bool("EnableHttpListen"); err == nil {
EnableHttpListen = v
}
if maxmemory, err := AppConfig.Int64("MaxMemory"); err == nil {
MaxMemory = maxmemory
}
if appname := AppConfig.String("AppName"); appname != "" {
AppName = appname
}
if autorender, err := AppConfig.Bool("AutoRender"); err == nil {
AutoRender = autorender
}
if autorecover, err := AppConfig.Bool("RecoverPanic"); err == nil {
RecoverPanic = autorecover
}
if views := AppConfig.String("ViewsPath"); views != "" {
ViewsPath = views
}
if sessionon, err := AppConfig.Bool("SessionOn"); err == nil {
SessionOn = sessionon
}
if sessProvider := AppConfig.String("SessionProvider"); sessProvider != "" {
SessionProvider = sessProvider
}
if sessName := AppConfig.String("SessionName"); sessName != "" {
SessionName = sessName
}
if sesssavepath := AppConfig.String("SessionSavePath"); sesssavepath != "" {
SessionSavePath = sesssavepath
}
if sessMaxLifeTime, err := AppConfig.Int64("SessionGCMaxLifetime"); err == nil && sessMaxLifeTime != 0 {
SessionGCMaxLifetime = sessMaxLifeTime
}
if sesscookielifetime, err := AppConfig.Int("SessionCookieLifeTime"); err == nil && sesscookielifetime != 0 {
SessionCookieLifeTime = sesscookielifetime
}
if usefcgi, err := AppConfig.Bool("UseFcgi"); err == nil {
UseFcgi = usefcgi
}
if enablegzip, err := AppConfig.Bool("EnableGzip"); err == nil {
EnableGzip = enablegzip
}
if directoryindex, err := AppConfig.Bool("DirectoryIndex"); err == nil {
DirectoryIndex = directoryindex
}
if timeout, err := AppConfig.Int64("HttpServerTimeOut"); err == nil {
HttpServerTimeOut = timeout
}
if errorsshow, err := AppConfig.Bool("ErrorsShow"); err == nil {
ErrorsShow = errorsshow
}
if copyrequestbody, err := AppConfig.Bool("CopyRequestBody"); err == nil {
CopyRequestBody = copyrequestbody
}
if xsrfkey := AppConfig.String("XSRFKEY"); xsrfkey != "" {
XSRFKEY = xsrfkey
}
if enablexsrf, err := AppConfig.Bool("EnableXSRF"); err == nil {
EnableXSRF = enablexsrf
}
if expire, err := AppConfig.Int("XSRFExpire"); err == nil {
XSRFExpire = expire
}
if tplleft := AppConfig.String("TemplateLeft"); tplleft != "" {
TemplateLeft = tplleft
}
if tplright := AppConfig.String("TemplateRight"); tplright != "" {
TemplateRight = tplright
}
if httptls, err := AppConfig.Bool("EnableHttpTLS"); err == nil {
EnableHttpTLS = httptls
}
if httpsport, err := AppConfig.Int("HttpsPort"); err == nil {
HttpsPort = httpsport
}
if certfile := AppConfig.String("HttpCertFile"); certfile != "" {
HttpCertFile = certfile
}
if keyfile := AppConfig.String("HttpKeyFile"); keyfile != "" {
HttpKeyFile = keyfile
}
if serverName := AppConfig.String("BeegoServerName"); serverName != "" {
BeegoServerName = serverName
}
if flashname := AppConfig.String("FlashName"); flashname != "" {
FlashName = flashname
}
if flashseperator := AppConfig.String("FlashSeperator"); flashseperator != "" {
FlashSeperator = flashseperator
}
if sd := AppConfig.String("StaticDir"); sd != "" {
for k := range StaticDir {
delete(StaticDir, k)
}
sds := strings.Fields(sd)
for _, v := range sds {
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
} else {
StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
}
}
}
if sgz := AppConfig.String("StaticExtensionsToGzip"); sgz != "" {
extensions := strings.Split(sgz, ",")
if len(extensions) > 0 {
StaticExtensionsToGzip = []string{}
for _, ext := range extensions {
if len(ext) == 0 {
continue
}
extWithDot := ext
if extWithDot[:1] != "." {
extWithDot = "." + extWithDot
}
StaticExtensionsToGzip = append(StaticExtensionsToGzip, extWithDot)
}
}
}
if enableadmin, err := AppConfig.Bool("EnableAdmin"); err == nil {
EnableAdmin = enableadmin
}
if adminhttpaddr := AppConfig.String("AdminHttpAddr"); adminhttpaddr != "" {
AdminHttpAddr = adminhttpaddr
}
if adminhttpport, err := AppConfig.Int("AdminHttpPort"); err == nil {
AdminHttpPort = adminhttpport
}
if enabledocs, err := AppConfig.Bool("EnableDocs"); err == nil {
EnableDocs = enabledocs
}
if casesensitive, err := AppConfig.Bool("RouterCaseSensitive"); err == nil {
RouterCaseSensitive = casesensitive
}
return nil
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Package config is used to parse config
// Usage:
// import(
// "github.com/astaxie/beego/config"
@ -28,12 +29,12 @@
// cnf.Int64(key string) (int64, error)
// cnf.Bool(key string) (bool, error)
// cnf.Float(key string) (float64, error)
// cnf.DefaultString(key string, defaultval string) string
// cnf.DefaultStrings(key string, defaultval []string) []string
// cnf.DefaultInt(key string, defaultval int) int
// cnf.DefaultInt64(key string, defaultval int64) int64
// cnf.DefaultBool(key string, defaultval bool) bool
// cnf.DefaultFloat(key string, defaultval float64) float64
// cnf.DefaultString(key string, defaultVal string) string
// cnf.DefaultStrings(key string, defaultVal []string) []string
// cnf.DefaultInt(key string, defaultVal int) int
// cnf.DefaultInt64(key string, defaultVal int64) int64
// cnf.DefaultBool(key string, defaultVal bool) bool
// cnf.DefaultFloat(key string, defaultVal float64) float64
// cnf.DIY(key string) (interface{}, error)
// cnf.GetSection(section string) (map[string]string, error)
// cnf.SaveConfigFile(filename string) error
@ -45,30 +46,30 @@ import (
"fmt"
)
// ConfigContainer defines how to get and set value from configuration raw data.
type ConfigContainer interface {
Set(key, val string) error // support section::key type in given key when using ini type.
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
// Configer defines how to get and set value from configuration raw data.
type Configer interface {
Set(key, val string) error //support section::key type in given key when using ini type.
String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Strings(key string) []string //get string slice
Int(key string) (int, error)
Int64(key string) (int64, error)
Bool(key string) (bool, error)
Float(key string) (float64, error)
DefaultString(key string, defaultval string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
DefaultStrings(key string, defaultval []string) []string //get string slice
DefaultInt(key string, defaultval int) int
DefaultInt64(key string, defaultval int64) int64
DefaultBool(key string, defaultval bool) bool
DefaultFloat(key string, defaultval float64) float64
DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
DefaultStrings(key string, defaultVal []string) []string //get string slice
DefaultInt(key string, defaultVal int) int
DefaultInt64(key string, defaultVal int64) int64
DefaultBool(key string, defaultVal bool) bool
DefaultFloat(key string, defaultVal float64) float64
DIY(key string) (interface{}, error)
GetSection(section string) (map[string]string, error)
SaveConfigFile(filename string) error
}
// Config is the adapter interface for parsing config file to get raw data to ConfigContainer.
// Config is the adapter interface for parsing config file to get raw data to Configer.
type Config interface {
Parse(key string) (ConfigContainer, error)
ParseData(data []byte) (ConfigContainer, error)
Parse(key string) (Configer, error)
ParseData(data []byte) (Configer, error)
}
var adapters = make(map[string]Config)
@ -86,19 +87,19 @@ func Register(name string, adapter Config) {
adapters[name] = adapter
}
// adapterName is ini/json/xml/yaml.
// NewConfig adapterName is ini/json/xml/yaml.
// filename is the config file path.
func NewConfig(adapterName, fileaname string) (ConfigContainer, error) {
func NewConfig(adapterName, filename string) (Configer, error) {
adapter, ok := adapters[adapterName]
if !ok {
return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName)
}
return adapter.Parse(fileaname)
return adapter.Parse(filename)
}
// adapterName is ini/json/xml/yaml.
// NewConfigData adapterName is ini/json/xml/yaml.
// data is the config data.
func NewConfigData(adapterName string, data []byte) (ConfigContainer, error) {
func NewConfigData(adapterName string, data []byte) (Configer, error) {
adapter, ok := adapters[adapterName]
if !ok {
return nil, fmt.Errorf("config: unknown adaptername %q (forgotten import?)", adapterName)

View File

@ -38,11 +38,11 @@ func (c *fakeConfigContainer) String(key string) string {
}
func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string {
if v := c.getData(key); v == "" {
v := c.getData(key)
if v == "" {
return defaultval
} else {
return v
}
return v
}
func (c *fakeConfigContainer) Strings(key string) []string {
@ -50,11 +50,11 @@ func (c *fakeConfigContainer) Strings(key string) []string {
}
func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) == 0 {
v := c.Strings(key)
if len(v) == 0 {
return defaultval
} else {
return v
}
return v
}
func (c *fakeConfigContainer) Int(key string) (int, error) {
@ -62,11 +62,11 @@ func (c *fakeConfigContainer) Int(key string) (int, error) {
}
func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err != nil {
v, err := c.Int(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
@ -74,11 +74,11 @@ func (c *fakeConfigContainer) Int64(key string) (int64, error) {
}
func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err != nil {
v, err := c.Int64(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
@ -86,11 +86,11 @@ func (c *fakeConfigContainer) Bool(key string) (bool, error) {
}
func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err != nil {
v, err := c.Bool(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
func (c *fakeConfigContainer) Float(key string) (float64, error) {
@ -98,11 +98,11 @@ func (c *fakeConfigContainer) Float(key string) (float64, error) {
}
func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err != nil {
v, err := c.Float(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
@ -120,9 +120,10 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error {
return errors.New("not implement in the fakeConfigContainer")
}
var _ ConfigContainer = new(fakeConfigContainer)
var _ Configer = new(fakeConfigContainer)
func NewFakeConfig() ConfigContainer {
// NewFakeConfig return a fake Congiger
func NewFakeConfig() Configer {
return &fakeConfigContainer{
data: make(map[string]string),
}

View File

@ -31,7 +31,7 @@ import (
)
var (
DEFAULT_SECTION = "default" // default section means if some ini items not in a section, make them in default section,
defaultSection = "default" // default section means if some ini items not in a section, make them in default section,
bNumComment = []byte{'#'} // number signal
bSemComment = []byte{';'} // semicolon signal
bEmpty = []byte{}
@ -46,8 +46,8 @@ var (
type IniConfig struct {
}
// ParseFile creates a new Config and parses the file configuration from the named file.
func (ini *IniConfig) Parse(name string) (ConfigContainer, error) {
// Parse creates a new Config and parses the file configuration from the named file.
func (ini *IniConfig) Parse(name string) (Configer, error) {
return ini.parseFile(name)
}
@ -77,7 +77,7 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
buf.ReadByte()
}
}
section := DEFAULT_SECTION
section := defaultSection
for {
line, _, err := buf.ReadLine()
if err == io.EOF {
@ -171,7 +171,8 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
return cfg, nil
}
func (ini *IniConfig) ParseData(data []byte) (ConfigContainer, error) {
// ParseData parse ini the data
func (ini *IniConfig) ParseData(data []byte) (Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
@ -181,7 +182,7 @@ func (ini *IniConfig) ParseData(data []byte) (ConfigContainer, error) {
return ini.Parse(tmpName)
}
// A Config represents the ini configuration.
// IniConfigContainer A Config represents the ini configuration.
// When set and get value, support key as section:name type.
type IniConfigContainer struct {
filename string
@ -199,11 +200,11 @@ func (c *IniConfigContainer) Bool(key string) (bool, error) {
// DefaultBool returns the boolean value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err != nil {
v, err := c.Bool(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// Int returns the integer value for a given key.
@ -214,11 +215,11 @@ func (c *IniConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err != nil {
v, err := c.Int(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// Int64 returns the int64 value for a given key.
@ -229,11 +230,11 @@ func (c *IniConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err != nil {
v, err := c.Int64(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// Float returns the float value for a given key.
@ -244,11 +245,11 @@ func (c *IniConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err != nil {
v, err := c.Float(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// String returns the string value for a given key.
@ -259,11 +260,11 @@ func (c *IniConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultString(key string, defaultval string) string {
if v := c.String(key); v == "" {
v := c.String(key)
if v == "" {
return defaultval
} else {
return v
}
return v
}
// Strings returns the []string value for a given key.
@ -274,20 +275,19 @@ func (c *IniConfigContainer) Strings(key string) []string {
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) == 0 {
v := c.Strings(key)
if len(v) == 0 {
return defaultval
} else {
return v
}
return v
}
// GetSection returns map for the given section
func (c *IniConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v, nil
} else {
return nil, errors.New("not exist setction")
}
return nil, errors.New("not exist setction")
}
// SaveConfigFile save the config into file
@ -300,7 +300,32 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
defer f.Close()
buf := bytes.NewBuffer(nil)
// Save default section at first place
if dt, ok := c.data[defaultSection]; ok {
for key, val := range dt {
if key != " " {
// Write key comments.
if v, ok := c.keyComment[key]; ok {
if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil {
return err
}
}
// Write key and value.
if _, err = buf.WriteString(key + string(bEqual) + val + lineBreak); err != nil {
return err
}
}
}
// Put a line between sections.
if _, err = buf.WriteString(lineBreak); err != nil {
return err
}
}
// Save named sections
for section, dt := range c.data {
if section != defaultSection {
// Write section comments.
if v, ok := c.sectionComment[section]; ok {
if _, err = buf.WriteString(string(bNumComment) + v + lineBreak); err != nil {
@ -308,12 +333,10 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
}
}
if section != DEFAULT_SECTION {
// Write section name.
if _, err = buf.WriteString(string(sectionStart) + section + string(sectionEnd) + lineBreak); err != nil {
return err
}
}
for key, val := range dt {
if key != " " {
@ -336,6 +359,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
return err
}
}
}
if _, err = buf.WriteTo(f); err != nil {
return err
@ -343,7 +367,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
return nil
}
// WriteValue writes a new value for key.
// Set writes a new value for key.
// if write to one section, the key need be "section::key".
// if the section is not existed, it panics.
func (c *IniConfigContainer) Set(key, value string) error {
@ -355,14 +379,14 @@ func (c *IniConfigContainer) Set(key, value string) error {
var (
section, k string
sectionKey []string = strings.Split(key, "::")
sectionKey = strings.Split(key, "::")
)
if len(sectionKey) >= 2 {
section = sectionKey[0]
k = sectionKey[1]
} else {
section = DEFAULT_SECTION
section = defaultSection
k = sectionKey[0]
}
@ -391,13 +415,13 @@ func (c *IniConfigContainer) getdata(key string) string {
var (
section, k string
sectionKey []string = strings.Split(strings.ToLower(key), "::")
sectionKey = strings.Split(strings.ToLower(key), "::")
)
if len(sectionKey) >= 2 {
section = sectionKey[0]
k = sectionKey[1]
} else {
section = DEFAULT_SECTION
section = defaultSection
k = sectionKey[0]
}
if v, ok := c.data[section]; ok {

View File

@ -23,12 +23,12 @@ import (
"sync"
)
// JsonConfig is a json config parser and implements Config interface.
type JsonConfig struct {
// JSONConfig is a json config parser and implements Config interface.
type JSONConfig struct {
}
// Parse returns a ConfigContainer with parsed json config map.
func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
func (js *JSONConfig) Parse(filename string) (Configer, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
@ -43,8 +43,8 @@ func (js *JsonConfig) Parse(filename string) (ConfigContainer, error) {
}
// ParseData returns a ConfigContainer with json string
func (js *JsonConfig) ParseData(data []byte) (ConfigContainer, error) {
x := &JsonConfigContainer{
func (js *JSONConfig) ParseData(data []byte) (Configer, error) {
x := &JSONConfigContainer{
data: make(map[string]interface{}),
}
err := json.Unmarshal(data, &x.data)
@ -59,15 +59,15 @@ func (js *JsonConfig) ParseData(data []byte) (ConfigContainer, error) {
return x, nil
}
// A Config represents the json configuration.
// JSONConfigContainer A Config represents the json configuration.
// Only when get value, support key as section:name type.
type JsonConfigContainer struct {
type JSONConfigContainer struct {
data map[string]interface{}
sync.RWMutex
}
// Bool returns the boolean value for a given key.
func (c *JsonConfigContainer) Bool(key string) (bool, error) {
func (c *JSONConfigContainer) Bool(key string) (bool, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(bool); ok {
@ -80,7 +80,7 @@ func (c *JsonConfigContainer) Bool(key string) (bool, error) {
// DefaultBool return the bool value if has no error
// otherwise return the defaultval
func (c *JsonConfigContainer) DefaultBool(key string, defaultval bool) bool {
func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err == nil {
return v
}
@ -88,7 +88,7 @@ func (c *JsonConfigContainer) DefaultBool(key string, defaultval bool) bool {
}
// Int returns the integer value for a given key.
func (c *JsonConfigContainer) Int(key string) (int, error) {
func (c *JSONConfigContainer) Int(key string) (int, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
@ -101,7 +101,7 @@ func (c *JsonConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
func (c *JsonConfigContainer) DefaultInt(key string, defaultval int) int {
func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err == nil {
return v
}
@ -109,7 +109,7 @@ func (c *JsonConfigContainer) DefaultInt(key string, defaultval int) int {
}
// Int64 returns the int64 value for a given key.
func (c *JsonConfigContainer) Int64(key string) (int64, error) {
func (c *JSONConfigContainer) Int64(key string) (int64, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
@ -122,7 +122,7 @@ func (c *JsonConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
func (c *JsonConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err == nil {
return v
}
@ -130,7 +130,7 @@ func (c *JsonConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
}
// Float returns the float value for a given key.
func (c *JsonConfigContainer) Float(key string) (float64, error) {
func (c *JSONConfigContainer) Float(key string) (float64, error) {
val := c.getData(key)
if val != nil {
if v, ok := val.(float64); ok {
@ -143,7 +143,7 @@ func (c *JsonConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
func (c *JsonConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err == nil {
return v
}
@ -151,7 +151,7 @@ func (c *JsonConfigContainer) DefaultFloat(key string, defaultval float64) float
}
// String returns the string value for a given key.
func (c *JsonConfigContainer) String(key string) string {
func (c *JSONConfigContainer) String(key string) string {
val := c.getData(key)
if val != nil {
if v, ok := val.(string); ok {
@ -163,8 +163,8 @@ func (c *JsonConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
func (c *JsonConfigContainer) DefaultString(key string, defaultval string) string {
// TODO FIXME should not use "" to replace non existance
func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string {
// TODO FIXME should not use "" to replace non existence
if v := c.String(key); v != "" {
return v
}
@ -172,7 +172,7 @@ func (c *JsonConfigContainer) DefaultString(key string, defaultval string) strin
}
// Strings returns the []string value for a given key.
func (c *JsonConfigContainer) Strings(key string) []string {
func (c *JSONConfigContainer) Strings(key string) []string {
stringVal := c.String(key)
if stringVal == "" {
return []string{}
@ -182,7 +182,7 @@ func (c *JsonConfigContainer) Strings(key string) []string {
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
func (c *JsonConfigContainer) DefaultStrings(key string, defaultval []string) []string {
func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) > 0 {
return v
}
@ -190,7 +190,7 @@ func (c *JsonConfigContainer) DefaultStrings(key string, defaultval []string) []
}
// GetSection returns map for the given section
func (c *JsonConfigContainer) GetSection(section string) (map[string]string, error) {
func (c *JSONConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v.(map[string]string), nil
}
@ -198,7 +198,7 @@ func (c *JsonConfigContainer) GetSection(section string) (map[string]string, err
}
// SaveConfigFile save the config into file
func (c *JsonConfigContainer) SaveConfigFile(filename string) (err error) {
func (c *JSONConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
@ -214,7 +214,7 @@ func (c *JsonConfigContainer) SaveConfigFile(filename string) (err error) {
}
// Set writes a new value for key.
func (c *JsonConfigContainer) Set(key, val string) error {
func (c *JSONConfigContainer) Set(key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
@ -222,7 +222,7 @@ func (c *JsonConfigContainer) Set(key, val string) error {
}
// DIY returns the raw value by a given key.
func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
func (c *JSONConfigContainer) DIY(key string) (v interface{}, err error) {
val := c.getData(key)
if val != nil {
return val, nil
@ -231,7 +231,7 @@ func (c *JsonConfigContainer) DIY(key string) (v interface{}, err error) {
}
// section.key or key
func (c *JsonConfigContainer) getData(key string) interface{} {
func (c *JSONConfigContainer) getData(key string) interface{} {
if len(key) == 0 {
return nil
}
@ -261,5 +261,5 @@ func (c *JsonConfigContainer) getData(key string) interface{} {
}
func init() {
Register("json", &JsonConfig{})
Register("json", &JSONConfig{})
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// package xml for config provider
// Package xml for config provider
//
// depend on github.com/beego/x2j
//
@ -45,20 +45,20 @@ import (
"github.com/beego/x2j"
)
// XmlConfig is a xml config parser and implements Config interface.
// Config is a xml config parser and implements Config interface.
// xml configurations should be included in <config></config> tag.
// only support key/value pair as <key>value</key> as each item.
type XMLConfig struct{}
type Config struct{}
// Parse returns a ConfigContainer with parsed xml config map.
func (xc *XMLConfig) Parse(filename string) (config.ConfigContainer, error) {
func (xc *Config) Parse(filename string) (config.Configer, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()
x := &XMLConfigContainer{data: make(map[string]interface{})}
x := &ConfigContainer{data: make(map[string]interface{})}
content, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
@ -73,84 +73,86 @@ func (xc *XMLConfig) Parse(filename string) (config.ConfigContainer, error) {
return x, nil
}
func (x *XMLConfig) ParseData(data []byte) (config.ConfigContainer, error) {
// ParseData xml data
func (xc *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
return nil, err
}
return x.Parse(tmpName)
return xc.Parse(tmpName)
}
// A Config represents the xml configuration.
type XMLConfigContainer struct {
// ConfigContainer A Config represents the xml configuration.
type ConfigContainer struct {
data map[string]interface{}
sync.Mutex
}
// Bool returns the boolean value for a given key.
func (c *XMLConfigContainer) Bool(key string) (bool, error) {
func (c *ConfigContainer) Bool(key string) (bool, error) {
return strconv.ParseBool(c.data[key].(string))
}
// DefaultBool return the bool value if has no error
// otherwise return the defaultval
func (c *XMLConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err != nil {
func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
v, err := c.Bool(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// Int returns the integer value for a given key.
func (c *XMLConfigContainer) Int(key string) (int, error) {
func (c *ConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.data[key].(string))
}
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
func (c *XMLConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err != nil {
func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
v, err := c.Int(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// Int64 returns the int64 value for a given key.
func (c *XMLConfigContainer) Int64(key string) (int64, error) {
func (c *ConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.data[key].(string), 10, 64)
}
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
func (c *XMLConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err != nil {
func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
v, err := c.Int64(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// Float returns the float value for a given key.
func (c *XMLConfigContainer) Float(key string) (float64, error) {
func (c *ConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.data[key].(string), 64)
}
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
func (c *XMLConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err != nil {
func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
v, err := c.Float(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// String returns the string value for a given key.
func (c *XMLConfigContainer) String(key string) string {
func (c *ConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok {
return v
}
@ -159,40 +161,39 @@ func (c *XMLConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
func (c *XMLConfigContainer) DefaultString(key string, defaultval string) string {
if v := c.String(key); v == "" {
func (c *ConfigContainer) DefaultString(key string, defaultval string) string {
v := c.String(key)
if v == "" {
return defaultval
} else {
return v
}
return v
}
// Strings returns the []string value for a given key.
func (c *XMLConfigContainer) Strings(key string) []string {
func (c *ConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
func (c *XMLConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) == 0 {
func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string {
v := c.Strings(key)
if len(v) == 0 {
return defaultval
} else {
return v
}
return v
}
// GetSection returns map for the given section
func (c *XMLConfigContainer) GetSection(section string) (map[string]string, error) {
func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
return v.(map[string]string), nil
} else {
return nil, errors.New("not exist setction")
}
return nil, errors.New("not exist setction")
}
// SaveConfigFile save the config into file
func (c *XMLConfigContainer) SaveConfigFile(filename string) (err error) {
func (c *ConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
@ -207,8 +208,8 @@ func (c *XMLConfigContainer) SaveConfigFile(filename string) (err error) {
return err
}
// WriteValue writes a new value for key.
func (c *XMLConfigContainer) Set(key, val string) error {
// Set writes a new value for key.
func (c *ConfigContainer) Set(key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
@ -216,7 +217,7 @@ func (c *XMLConfigContainer) Set(key, val string) error {
}
// DIY returns the raw value by a given key.
func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok {
return v, nil
}
@ -224,5 +225,5 @@ func (c *XMLConfigContainer) DIY(key string) (v interface{}, err error) {
}
func init() {
config.Register("xml", &XMLConfig{})
config.Register("xml", &Config{})
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// package yaml for config provider
// Package yaml for config provider
//
// depend on github.com/beego/goyaml2
//
@ -46,22 +46,23 @@ import (
"github.com/beego/goyaml2"
)
// YAMLConfig is a yaml config parser and implements Config interface.
type YAMLConfig struct{}
// Config is a yaml config parser and implements Config interface.
type Config struct{}
// Parse returns a ConfigContainer with parsed yaml config map.
func (yaml *YAMLConfig) Parse(filename string) (y config.ConfigContainer, err error) {
func (yaml *Config) Parse(filename string) (y config.Configer, err error) {
cnf, err := ReadYmlReader(filename)
if err != nil {
return
}
y = &YAMLConfigContainer{
y = &ConfigContainer{
data: cnf,
}
return
}
func (yaml *YAMLConfig) ParseData(data []byte) (config.ConfigContainer, error) {
// ParseData parse yaml data
func (yaml *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
@ -71,7 +72,7 @@ func (yaml *YAMLConfig) ParseData(data []byte) (config.ConfigContainer, error) {
return yaml.Parse(tmpName)
}
// Read yaml file to map.
// ReadYmlReader Read yaml file to map.
// if json like, use json package, unless goyaml2 package.
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
f, err := os.Open(path)
@ -112,14 +113,14 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
return
}
// A Config represents the yaml configuration.
type YAMLConfigContainer struct {
// ConfigContainer A Config represents the yaml configuration.
type ConfigContainer struct {
data map[string]interface{}
sync.Mutex
}
// Bool returns the boolean value for a given key.
func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
func (c *ConfigContainer) Bool(key string) (bool, error) {
if v, ok := c.data[key].(bool); ok {
return v, nil
}
@ -128,16 +129,16 @@ func (c *YAMLConfigContainer) Bool(key string) (bool, error) {
// DefaultBool return the bool value if has no error
// otherwise return the defaultval
func (c *YAMLConfigContainer) DefaultBool(key string, defaultval bool) bool {
if v, err := c.Bool(key); err != nil {
func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool {
v, err := c.Bool(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// Int returns the integer value for a given key.
func (c *YAMLConfigContainer) Int(key string) (int, error) {
func (c *ConfigContainer) Int(key string) (int, error) {
if v, ok := c.data[key].(int64); ok {
return int(v), nil
}
@ -146,16 +147,16 @@ func (c *YAMLConfigContainer) Int(key string) (int, error) {
// DefaultInt returns the integer value for a given key.
// if err != nil return defaltval
func (c *YAMLConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err != nil {
func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
v, err := c.Int(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// Int64 returns the int64 value for a given key.
func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
func (c *ConfigContainer) Int64(key string) (int64, error) {
if v, ok := c.data[key].(int64); ok {
return v, nil
}
@ -164,16 +165,16 @@ func (c *YAMLConfigContainer) Int64(key string) (int64, error) {
// DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval
func (c *YAMLConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err != nil {
func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
v, err := c.Int64(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// Float returns the float value for a given key.
func (c *YAMLConfigContainer) Float(key string) (float64, error) {
func (c *ConfigContainer) Float(key string) (float64, error) {
if v, ok := c.data[key].(float64); ok {
return v, nil
}
@ -182,16 +183,16 @@ func (c *YAMLConfigContainer) Float(key string) (float64, error) {
// DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval
func (c *YAMLConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err != nil {
func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
v, err := c.Float(key)
if err != nil {
return defaultval
} else {
return v
}
return v
}
// String returns the string value for a given key.
func (c *YAMLConfigContainer) String(key string) string {
func (c *ConfigContainer) String(key string) string {
if v, ok := c.data[key].(string); ok {
return v
}
@ -200,40 +201,40 @@ func (c *YAMLConfigContainer) String(key string) string {
// DefaultString returns the string value for a given key.
// if err != nil return defaltval
func (c *YAMLConfigContainer) DefaultString(key string, defaultval string) string {
if v := c.String(key); v == "" {
func (c *ConfigContainer) DefaultString(key string, defaultval string) string {
v := c.String(key)
if v == "" {
return defaultval
} else {
return v
}
return v
}
// Strings returns the []string value for a given key.
func (c *YAMLConfigContainer) Strings(key string) []string {
func (c *ConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval
func (c *YAMLConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); len(v) == 0 {
func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string {
v := c.Strings(key)
if len(v) == 0 {
return defaultval
} else {
return v
}
return v
}
// GetSection returns map for the given section
func (c *YAMLConfigContainer) GetSection(section string) (map[string]string, error) {
if v, ok := c.data[section]; ok {
func (c *ConfigContainer) GetSection(section string) (map[string]string, error) {
v, ok := c.data[section]
if ok {
return v.(map[string]string), nil
} else {
return nil, errors.New("not exist setction")
}
return nil, errors.New("not exist setction")
}
// SaveConfigFile save the config into file
func (c *YAMLConfigContainer) SaveConfigFile(filename string) (err error) {
func (c *ConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename.
f, err := os.Create(filename)
if err != nil {
@ -244,8 +245,8 @@ func (c *YAMLConfigContainer) SaveConfigFile(filename string) (err error) {
return err
}
// WriteValue writes a new value for key.
func (c *YAMLConfigContainer) Set(key, val string) error {
// Set writes a new value for key.
func (c *ConfigContainer) Set(key, val string) error {
c.Lock()
defer c.Unlock()
c.data[key] = val
@ -253,7 +254,7 @@ func (c *YAMLConfigContainer) Set(key, val string) error {
}
// DIY returns the raw value by a given key.
func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
if v, ok := c.data[key]; ok {
return v, nil
}
@ -261,5 +262,5 @@ func (c *YAMLConfigContainer) DIY(key string) (v interface{}, err error) {
}
func init() {
config.Register("yaml", &YAMLConfig{})
config.Register("yaml", &Config{})
}

View File

@ -19,11 +19,11 @@ import (
)
func TestDefaults(t *testing.T) {
if FlashName != "BEEGO_FLASH" {
if BConfig.WebConfig.FlashName != "BEEGO_FLASH" {
t.Errorf("FlashName was not set to default.")
}
if FlashSeperator != "BEEGOFLASH" {
if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" {
t.Errorf("FlashName was not set to default.")
}
}

197
context/acceptencoder.go Normal file
View File

@ -0,0 +1,197 @@
// Copyright 2015 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"bytes"
"compress/flate"
"compress/gzip"
"compress/zlib"
"io"
"net/http"
"os"
"strconv"
"strings"
"sync"
)
type resetWriter interface {
io.Writer
Reset(w io.Writer)
}
type nopResetWriter struct {
io.Writer
}
func (n nopResetWriter) Reset(w io.Writer) {
//do nothing
}
type acceptEncoder struct {
name string
levelEncode func(int) resetWriter
bestSpeedPool *sync.Pool
bestCompressionPool *sync.Pool
}
func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil {
return nopResetWriter{wr}
}
var rwr resetWriter
switch level {
case flate.BestSpeed:
rwr = ac.bestSpeedPool.Get().(resetWriter)
case flate.BestCompression:
rwr = ac.bestCompressionPool.Get().(resetWriter)
default:
rwr = ac.levelEncode(level)
}
rwr.Reset(wr)
return rwr
}
func (ac acceptEncoder) put(wr resetWriter, level int) {
if ac.bestSpeedPool == nil || ac.bestCompressionPool == nil {
return
}
wr.Reset(nil)
switch level {
case flate.BestSpeed:
ac.bestSpeedPool.Put(wr)
case flate.BestCompression:
ac.bestCompressionPool.Put(wr)
}
}
var (
noneCompressEncoder = acceptEncoder{"", nil, nil, nil}
gzipCompressEncoder = acceptEncoder{"gzip",
func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr },
&sync.Pool{
New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestSpeed); return wr },
},
&sync.Pool{
New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr },
},
}
//according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed
//deflate
//The "zlib" format defined in RFC 1950 [31] in combination with
//the "deflate" compression mechanism described in RFC 1951 [29].
deflateCompressEncoder = acceptEncoder{"deflate",
func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr },
&sync.Pool{
New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestSpeed); return wr },
},
&sync.Pool{
New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr },
},
}
)
var (
encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore
"gzip": gzipCompressEncoder,
"deflate": deflateCompressEncoder,
"*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip
"identity": noneCompressEncoder, // identity means none-compress
}
)
// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate)
func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) {
return writeLevel(encoding, writer, file, flate.BestCompression)
}
// WriteBody reads writes content to writer by the specific encoding(gzip/deflate)
func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) {
return writeLevel(encoding, writer, bytes.NewReader(content), flate.BestSpeed)
}
// writeLevel reads from reader,writes to writer by specific encoding and compress level
// the compress level is defined by deflate package
func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) {
var outputWriter resetWriter
var err error
var ce = noneCompressEncoder
if cf, ok := encoderMap[encoding]; ok {
ce = cf
}
encoding = ce.name
outputWriter = ce.encode(writer, level)
defer ce.put(outputWriter, level)
_, err = io.Copy(outputWriter, reader)
if err != nil {
return false, "", err
}
switch outputWriter.(type) {
case io.WriteCloser:
outputWriter.(io.WriteCloser).Close()
}
return encoding != "", encoding, nil
}
// ParseEncoding will extract the right encoding for response
// the Accept-Encoding's sec is here:
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3
func ParseEncoding(r *http.Request) string {
if r == nil {
return ""
}
return parseEncoding(r)
}
type q struct {
name string
value float64
}
func parseEncoding(r *http.Request) string {
acceptEncoding := r.Header.Get("Accept-Encoding")
if acceptEncoding == "" {
return ""
}
var lastQ q
for _, v := range strings.Split(acceptEncoding, ",") {
v = strings.TrimSpace(v)
if v == "" {
continue
}
vs := strings.Split(v, ";")
if len(vs) == 1 {
lastQ = q{vs[0], 1}
break
}
if len(vs) == 2 {
f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64)
if f == 0 {
continue
}
if f > lastQ.value {
lastQ = q{vs[0], f}
}
}
}
if cf, ok := encoderMap[lastQ.name]; ok {
return cf.name
}
return ""
}

View File

@ -0,0 +1,44 @@
// Copyright 2015 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"net/http"
"testing"
)
func Test_ExtractEncoding(t *testing.T) {
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,deflate"}}}) != "gzip" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate,gzip"}}}) != "deflate" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate"}}}) != "deflate" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0,deflate"}}}) != "deflate" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"*"}}}) != "gzip" {
t.Fail()
}
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Package context provide the context utils
// Usage:
//
// import "github.com/astaxie/beego/context"
@ -22,10 +23,13 @@
package context
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"strings"
@ -34,14 +38,30 @@ import (
"github.com/astaxie/beego/utils"
)
// Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
// NewContext return the Context with Input and Output
func NewContext() *Context {
return &Context{
Input: NewInput(),
Output: NewOutput(),
}
}
// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
// BeegoInput and BeegoOutput provides some api to operate request and response more easily.
type Context struct {
Input *BeegoInput
Output *BeegoOutput
Request *http.Request
ResponseWriter http.ResponseWriter
_xsrf_token string
ResponseWriter *Response
_xsrfToken string
}
// Reset init Context, BeegoInput and BeegoOutput
func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
ctx.Request = r
ctx.ResponseWriter = &Response{rw, false, 0}
ctx.Input.Reset(ctx)
ctx.Output.Reset(ctx)
}
// Redirect does redirection to localurl with http header status code.
@ -54,29 +74,28 @@ func (ctx *Context) Redirect(status int, localurl string) {
// Abort stops this request.
// if beego.ErrorMaps exists, panic body.
func (ctx *Context) Abort(status int, body string) {
ctx.ResponseWriter.WriteHeader(status)
panic(body)
}
// Write string to response body.
// WriteString Write string to response body.
// it sends response body.
func (ctx *Context) WriteString(content string) {
ctx.ResponseWriter.Write([]byte(content))
}
// Get cookie from request by a given key.
// GetCookie Get cookie from request by a given key.
// It's alias of BeegoInput.Cookie.
func (ctx *Context) GetCookie(key string) string {
return ctx.Input.Cookie(key)
}
// Set cookie for response.
// SetCookie Set cookie for response.
// It's alias of BeegoOutput.Cookie.
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
ctx.Output.Cookie(name, value, others...)
}
// Get secure cookie from request by a given key.
// GetSecureCookie Get secure cookie from request by a given key.
func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
val := ctx.Input.Cookie(key)
if val == "" {
@ -103,7 +122,7 @@ func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
return string(res), true
}
// Set Secure cookie for response.
// SetSecureCookie Set Secure cookie for response.
func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) {
vs := base64.URLEncoding.EncodeToString([]byte(value))
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
@ -114,23 +133,23 @@ func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interf
ctx.Output.Cookie(name, cookie, others...)
}
// XsrfToken creates a xsrf token string and returns.
func (ctx *Context) XsrfToken(key string, expire int64) string {
if ctx._xsrf_token == "" {
// XSRFToken creates a xsrf token string and returns.
func (ctx *Context) XSRFToken(key string, expire int64) string {
if ctx._xsrfToken == "" {
token, ok := ctx.GetSecureCookie(key, "_xsrf")
if !ok {
token = string(utils.RandomCreateBytes(32))
ctx.SetSecureCookie(key, "_xsrf", token, expire)
}
ctx._xsrf_token = token
ctx._xsrfToken = token
}
return ctx._xsrf_token
return ctx._xsrfToken
}
// CheckXsrfCookie checks xsrf token in this request is valid or not.
// CheckXSRFCookie checks xsrf token in this request is valid or not.
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
// or in form field value named as "_xsrf".
func (ctx *Context) CheckXsrfCookie() bool {
func (ctx *Context) CheckXSRFCookie() bool {
token := ctx.Input.Query("_xsrf")
if token == "" {
token = ctx.Request.Header.Get("X-Xsrftoken")
@ -142,9 +161,57 @@ func (ctx *Context) CheckXsrfCookie() bool {
ctx.Abort(403, "'_xsrf' argument missing from POST")
return false
}
if ctx._xsrf_token != token {
if ctx._xsrfToken != token {
ctx.Abort(403, "XSRF cookie does not match POST argument")
return false
}
return true
}
//Response is a wrapper for the http.ResponseWriter
//started set to true if response was written to then don't execute other handler
type Response struct {
http.ResponseWriter
Started bool
Status int
}
// Write writes the data to the connection as part of an HTTP reply,
// and sets `started` to true.
// started means the response has sent out.
func (w *Response) Write(p []byte) (int, error) {
w.Started = true
return w.ResponseWriter.Write(p)
}
// WriteHeader sends an HTTP response header with status code,
// and sets `started` to true.
func (w *Response) WriteHeader(code int) {
w.Status = code
w.Started = true
w.ResponseWriter.WriteHeader(code)
}
// Hijack hijacker for http
func (w *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj, ok := w.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("webserver doesn't support hijacking")
}
return hj.Hijack()
}
// Flush http.Flusher
func (w *Response) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// CloseNotify http.CloseNotifier
func (w *Response) CloseNotify() <-chan bool {
if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
return nil
}

View File

@ -17,50 +17,69 @@ package context
import (
"bytes"
"errors"
"io"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"regexp"
"strconv"
"strings"
"github.com/astaxie/beego/session"
)
// Regexes for checking the accept headers
// TODO make sure these are correct
var (
acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`)
acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`)
acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`)
maxParam = 50
)
// BeegoInput operates the http request header, data, cookie and body.
// it also contains router params and current session.
type BeegoInput struct {
CruSession session.SessionStore
Params map[string]string
Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
Request *http.Request
Context *Context
CruSession session.Store
pnames []string
pvalues []string
data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
RequestBody []byte
RunController reflect.Type
RunMethod string
}
// NewInput return BeegoInput generated by http.Request.
func NewInput(req *http.Request) *BeegoInput {
// NewInput return BeegoInput generated by Context.
func NewInput() *BeegoInput {
return &BeegoInput{
Params: make(map[string]string),
Data: make(map[interface{}]interface{}),
Request: req,
pnames: make([]string, 0, maxParam),
pvalues: make([]string, 0, maxParam),
data: make(map[interface{}]interface{}),
}
}
// Reset init the BeegoInput
func (input *BeegoInput) Reset(ctx *Context) {
input.Context = ctx
input.CruSession = nil
input.pnames = input.pnames[:0]
input.pvalues = input.pvalues[:0]
input.data = nil
input.RequestBody = []byte{}
}
// Protocol returns request protocol name, such as HTTP/1.1 .
func (input *BeegoInput) Protocol() string {
return input.Request.Proto
return input.Context.Request.Proto
}
// Uri returns full request url with query string, fragment.
func (input *BeegoInput) Uri() string {
return input.Request.RequestURI
// URI returns full request url with query string, fragment.
func (input *BeegoInput) URI() string {
return input.Context.Request.RequestURI
}
// Url returns request url path (without query string, fragment).
func (input *BeegoInput) Url() string {
return input.Request.URL.Path
// URL returns request url path (without query string, fragment).
func (input *BeegoInput) URL() string {
return input.Context.Request.URL.Path
}
// Site returns base site url as scheme://domain type.
@ -70,10 +89,10 @@ func (input *BeegoInput) Site() string {
// Scheme returns request scheme as "http" or "https".
func (input *BeegoInput) Scheme() string {
if input.Request.URL.Scheme != "" {
return input.Request.URL.Scheme
if input.Context.Request.URL.Scheme != "" {
return input.Context.Request.URL.Scheme
}
if input.Request.TLS == nil {
if input.Context.Request.TLS == nil {
return "http"
}
return "https"
@ -88,19 +107,19 @@ func (input *BeegoInput) Domain() string {
// Host returns host name.
// if no host info in request, return localhost.
func (input *BeegoInput) Host() string {
if input.Request.Host != "" {
hostParts := strings.Split(input.Request.Host, ":")
if input.Context.Request.Host != "" {
hostParts := strings.Split(input.Context.Request.Host, ":")
if len(hostParts) > 0 {
return hostParts[0]
}
return input.Request.Host
return input.Context.Request.Host
}
return "localhost"
}
// Method returns http request method.
func (input *BeegoInput) Method() string {
return input.Request.Method
return input.Context.Request.Method
}
// Is returns boolean of this request is on given method, such as Is("POST").
@ -108,37 +127,37 @@ func (input *BeegoInput) Is(method string) bool {
return input.Method() == method
}
// Is this a GET method request?
// IsGet Is this a GET method request?
func (input *BeegoInput) IsGet() bool {
return input.Is("GET")
}
// Is this a POST method request?
// IsPost Is this a POST method request?
func (input *BeegoInput) IsPost() bool {
return input.Is("POST")
}
// Is this a Head method request?
// IsHead Is this a Head method request?
func (input *BeegoInput) IsHead() bool {
return input.Is("HEAD")
}
// Is this a OPTIONS method request?
// IsOptions Is this a OPTIONS method request?
func (input *BeegoInput) IsOptions() bool {
return input.Is("OPTIONS")
}
// Is this a PUT method request?
// IsPut Is this a PUT method request?
func (input *BeegoInput) IsPut() bool {
return input.Is("PUT")
}
// Is this a DELETE method request?
// IsDelete Is this a DELETE method request?
func (input *BeegoInput) IsDelete() bool {
return input.Is("DELETE")
}
// Is this a PATCH method request?
// IsPatch Is this a PATCH method request?
func (input *BeegoInput) IsPatch() bool {
return input.Is("PATCH")
}
@ -163,6 +182,21 @@ func (input *BeegoInput) IsUpload() bool {
return strings.Contains(input.Header("Content-Type"), "multipart/form-data")
}
// AcceptsHTML Checks if request accepts html response
func (input *BeegoInput) AcceptsHTML() bool {
return acceptsHTMLRegex.MatchString(input.Header("Accept"))
}
// AcceptsXML Checks if request accepts xml response
func (input *BeegoInput) AcceptsXML() bool {
return acceptsXMLRegex.MatchString(input.Header("Accept"))
}
// AcceptsJSON Checks if request accepts json response
func (input *BeegoInput) AcceptsJSON() bool {
return acceptsJSONRegex.MatchString(input.Header("Accept"))
}
// IP returns request client ip.
// if in proxy, return first proxy id.
// if error, return 127.0.0.1.
@ -172,7 +206,7 @@ func (input *BeegoInput) IP() string {
rip := strings.Split(ips[0], ":")
return rip[0]
}
ip := strings.Split(input.Request.RemoteAddr, ":")
ip := strings.Split(input.Context.Request.RemoteAddr, ":")
if len(ip) > 0 {
if ip[0] != "[" {
return ip[0]
@ -212,7 +246,7 @@ func (input *BeegoInput) SubDomains() string {
// Port returns request client port.
// when error or empty, return 80.
func (input *BeegoInput) Port() int {
parts := strings.Split(input.Request.Host, ":")
parts := strings.Split(input.Context.Request.Host, ":")
if len(parts) == 2 {
port, _ := strconv.Atoi(parts[1])
return port
@ -225,35 +259,59 @@ func (input *BeegoInput) UserAgent() string {
return input.Header("User-Agent")
}
// ParamsLen return the length of the params
func (input *BeegoInput) ParamsLen() int {
return len(input.pnames)
}
// Param returns router param by a given key.
func (input *BeegoInput) Param(key string) string {
if v, ok := input.Params[key]; ok {
return v
for i, v := range input.pnames {
if v == key && i <= len(input.pvalues) {
return input.pvalues[i]
}
}
return ""
}
// Params returns the map[key]value.
func (input *BeegoInput) Params() map[string]string {
m := make(map[string]string)
for i, v := range input.pnames {
if i <= len(input.pvalues) {
m[v] = input.pvalues[i]
}
}
return m
}
// SetParam will set the param with key and value
func (input *BeegoInput) SetParam(key, val string) {
input.pvalues = append(input.pvalues, val)
input.pnames = append(input.pnames, key)
}
// Query returns input data item string by a given string.
func (input *BeegoInput) Query(key string) string {
if val := input.Param(key); val != "" {
return val
}
if input.Request.Form == nil {
input.Request.ParseForm()
if input.Context.Request.Form == nil {
input.Context.Request.ParseForm()
}
return input.Request.Form.Get(key)
return input.Context.Request.Form.Get(key)
}
// Header returns request header item string by a given string.
// if non-existed, return empty string.
func (input *BeegoInput) Header(key string) string {
return input.Request.Header.Get(key)
return input.Context.Request.Header.Get(key)
}
// Cookie returns request cookie item string by a given key.
// if non-existed, return empty string.
func (input *BeegoInput) Cookie(key string) string {
ck, err := input.Request.Cookie(key)
ck, err := input.Context.Request.Cookie(key)
if err != nil {
return ""
}
@ -267,18 +325,27 @@ func (input *BeegoInput) Session(key interface{}) interface{} {
}
// CopyBody returns the raw request body data as bytes.
func (input *BeegoInput) CopyBody() []byte {
requestbody, _ := ioutil.ReadAll(input.Request.Body)
input.Request.Body.Close()
func (input *BeegoInput) CopyBody(MaxMemory int64) []byte {
safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory}
requestbody, _ := ioutil.ReadAll(safe)
input.Context.Request.Body.Close()
bf := bytes.NewBuffer(requestbody)
input.Request.Body = ioutil.NopCloser(bf)
input.Context.Request.Body = ioutil.NopCloser(bf)
input.RequestBody = requestbody
return requestbody
}
// Data return the implicit data in the input
func (input *BeegoInput) Data() map[interface{}]interface{} {
if input.data == nil {
input.data = make(map[interface{}]interface{})
}
return input.data
}
// GetData returns the stored data in this context.
func (input *BeegoInput) GetData(key interface{}) interface{} {
if v, ok := input.Data[key]; ok {
if v, ok := input.data[key]; ok {
return v
}
return nil
@ -287,17 +354,20 @@ func (input *BeegoInput) GetData(key interface{}) interface{} {
// SetData stores data with given key in this context.
// This data are only available in this context.
func (input *BeegoInput) SetData(key, val interface{}) {
input.Data[key] = val
if input.data == nil {
input.data = make(map[interface{}]interface{})
}
input.data[key] = val
}
// parseForm or parseMultiForm based on Content-type
// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type
func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error {
// Parse the body depending on the content type.
if strings.Contains(input.Header("Content-Type"), "multipart/form-data") {
if err := input.Request.ParseMultipartForm(maxMemory); err != nil {
if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil {
return errors.New("Error parsing request body:" + err.Error())
}
} else if err := input.Request.ParseForm(); err != nil {
} else if err := input.Context.Request.ParseForm(); err != nil {
return errors.New("Error parsing request body:" + err.Error())
}
return nil
@ -306,7 +376,7 @@ func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error {
// Bind data from request.Form[key] to dest
// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie
// var id int beegoInput.Bind(&id, "id") id ==123
// var isok bool beegoInput.Bind(&isok, "isok") id ==true
// var isok bool beegoInput.Bind(&isok, "isok") isok ==true
// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2
// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2]
// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array]
@ -329,7 +399,7 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error {
}
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
rv := reflect.Zero(reflect.TypeOf(0))
rv := reflect.Zero(typ)
switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
val := input.Query(key)
@ -362,19 +432,19 @@ func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
}
rv = input.bindBool(val, typ)
case reflect.Slice:
rv = input.bindSlice(&input.Request.Form, key, typ)
rv = input.bindSlice(&input.Context.Request.Form, key, typ)
case reflect.Struct:
rv = input.bindStruct(&input.Request.Form, key, typ)
rv = input.bindStruct(&input.Context.Request.Form, key, typ)
case reflect.Ptr:
rv = input.bindPoint(key, typ)
case reflect.Map:
rv = input.bindMap(&input.Request.Form, key, typ)
rv = input.bindMap(&input.Context.Request.Form, key, typ)
}
return rv
}
func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value {
rv := reflect.Zero(reflect.TypeOf(0))
rv := reflect.Zero(typ)
switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
rv = input.bindInt(val, typ)

View File

@ -17,12 +17,15 @@ package context
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func TestParse(t *testing.T) {
r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
beegoInput := NewInput(r)
beegoInput := NewInput()
beegoInput.Context = NewContext()
beegoInput.Context.Reset(httptest.NewRecorder(), r)
beegoInput.ParseFormOrMulitForm(1 << 20)
var id int
@ -73,7 +76,9 @@ func TestParse(t *testing.T) {
func TestSubDomain(t *testing.T) {
r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
beegoInput := NewInput(r)
beegoInput := NewInput()
beegoInput.Context = NewContext()
beegoInput.Context.Reset(httptest.NewRecorder(), r)
subdomain := beegoInput.SubDomains()
if subdomain != "www" {
@ -81,13 +86,13 @@ func TestSubDomain(t *testing.T) {
}
r, _ = http.NewRequest("GET", "http://localhost/", nil)
beegoInput.Request = r
beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains())
}
r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil)
beegoInput.Request = r
beegoInput.Context.Request = r
if beegoInput.SubDomains() != "aa.bb" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
@ -101,13 +106,13 @@ func TestSubDomain(t *testing.T) {
*/
r, _ = http.NewRequest("GET", "http://example.com/", nil)
beegoInput.Request = r
beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil)
beegoInput.Request = r
beegoInput.Context.Request = r
if beegoInput.SubDomains() != "aa.bb.cc.dd" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}

View File

@ -16,8 +16,6 @@ package context
import (
"bytes"
"compress/flate"
"compress/gzip"
"encoding/json"
"encoding/xml"
"errors"
@ -29,6 +27,7 @@ import (
"path/filepath"
"strconv"
"strings"
"time"
)
// BeegoOutput does work for sending response header.
@ -44,6 +43,12 @@ func NewOutput() *BeegoOutput {
return &BeegoOutput{}
}
// Reset init BeegoOutput
func (output *BeegoOutput) Reset(ctx *Context) {
output.Context = ctx
output.Status = 0
}
// Header sets response header item string via given key.
func (output *BeegoOutput) Header(key, val string) {
output.Context.ResponseWriter.Header().Set(key, val)
@ -53,30 +58,16 @@ func (output *BeegoOutput) Header(key, val string) {
// if EnableGzip, compress content string.
// it sends out response body directly.
func (output *BeegoOutput) Body(content []byte) {
output_writer := output.Context.ResponseWriter.(io.Writer)
if output.EnableGzip == true && output.Context.Input.Header("Accept-Encoding") != "" {
splitted := strings.SplitN(output.Context.Input.Header("Accept-Encoding"), ",", -1)
encodings := make([]string, len(splitted))
for i, val := range splitted {
encodings[i] = strings.TrimSpace(val)
}
for _, val := range encodings {
if val == "gzip" {
output.Header("Content-Encoding", "gzip")
output_writer, _ = gzip.NewWriterLevel(output.Context.ResponseWriter, gzip.BestSpeed)
break
} else if val == "deflate" {
output.Header("Content-Encoding", "deflate")
output_writer, _ = flate.NewWriter(output.Context.ResponseWriter, flate.BestSpeed)
break
}
var encoding string
var buf = &bytes.Buffer{}
if output.EnableGzip {
encoding = ParseEncoding(output.Context.Request)
}
if b, n, _ := WriteBody(encoding, buf, content); b {
output.Header("Content-Encoding", n)
} else {
output.Header("Content-Length", strconv.Itoa(len(content)))
}
// Write status code if it has been set manually
// Set it to 0 afterwards to prevent "multiple response.WriteHeader calls"
if output.Status != 0 {
@ -84,13 +75,7 @@ func (output *BeegoOutput) Body(content []byte) {
output.Status = 0
}
output_writer.Write(content)
switch output_writer.(type) {
case *gzip.Writer:
output_writer.(*gzip.Writer).Close()
case *flate.Writer:
output_writer.(*flate.Writer).Close()
}
io.Copy(output.Context.ResponseWriter, buf)
}
// Cookie sets cookie value via given key.
@ -98,26 +83,24 @@ func (output *BeegoOutput) Body(content []byte) {
func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) {
var b bytes.Buffer
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
//fix cookie not work in IE
if len(others) > 0 {
var maxAge int64
switch v := others[0].(type) {
case int:
if v > 0 {
fmt.Fprintf(&b, "; Max-Age=%d", v)
} else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0")
}
case int64:
if v > 0 {
fmt.Fprintf(&b, "; Max-Age=%d", v)
} else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0")
}
maxAge = int64(v)
case int32:
if v > 0 {
fmt.Fprintf(&b, "; Max-Age=%d", v)
} else if v < 0 {
fmt.Fprintf(&b, "; Max-Age=0")
maxAge = int64(v)
case int64:
maxAge = v
}
if maxAge > 0 {
fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge)
} else {
fmt.Fprintf(&b, "; Max-Age=0")
}
}
@ -185,9 +168,9 @@ func sanitizeValue(v string) string {
return cookieValueSanitizer.Replace(v)
}
// Json writes json to response body.
// JSON writes json to response body.
// if coding is true, it converts utf-8 to \u0000 type.
func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) error {
func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error {
output.Header("Content-Type", "application/json; charset=utf-8")
var content []byte
var err error
@ -201,14 +184,14 @@ func (output *BeegoOutput) Json(data interface{}, hasIndent bool, coding bool) e
return err
}
if coding {
content = []byte(stringsToJson(string(content)))
content = []byte(stringsToJSON(string(content)))
}
output.Body(content)
return nil
}
// Jsonp writes jsonp to response body.
func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
// JSONP writes jsonp to response body.
func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/javascript; charset=utf-8")
var content []byte
var err error
@ -225,16 +208,16 @@ func (output *BeegoOutput) Jsonp(data interface{}, hasIndent bool) error {
if callback == "" {
return errors.New(`"callback" parameter required`)
}
callback_content := bytes.NewBufferString(" " + template.JSEscapeString(callback))
callback_content.WriteString("(")
callback_content.Write(content)
callback_content.WriteString(");\r\n")
output.Body(callback_content.Bytes())
callbackContent := bytes.NewBufferString(" " + template.JSEscapeString(callback))
callbackContent.WriteString("(")
callbackContent.Write(content)
callbackContent.WriteString(");\r\n")
output.Body(callbackContent.Bytes())
return nil
}
// Xml writes xml string to response body.
func (output *BeegoOutput) Xml(data interface{}, hasIndent bool) error {
// XML writes xml string to response body.
func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/xml; charset=utf-8")
var content []byte
var err error
@ -328,7 +311,7 @@ func (output *BeegoOutput) IsNotFound(status int) bool {
return output.Status == 404
}
// IsClient returns boolean of this request client sends error data.
// IsClientError returns boolean of this request client sends error data.
// HTTP 4xx means forbidden.
func (output *BeegoOutput) IsClientError(status int) bool {
return output.Status >= 400 && output.Status < 500
@ -340,7 +323,7 @@ func (output *BeegoOutput) IsServerError(status int) bool {
return output.Status >= 500 && output.Status < 600
}
func stringsToJson(str string) string {
func stringsToJSON(str string) string {
rs := []rune(str)
jsons := ""
for _, r := range rs {
@ -354,7 +337,7 @@ func stringsToJson(str string) string {
return jsons
}
// Sessions sets session item value with given key.
// Session sets session item value with given key.
func (output *BeegoOutput) Session(name interface{}, value interface{}) {
output.Context.Input.CruSession.Set(name, value)
}

View File

@ -19,7 +19,6 @@ import (
"errors"
"html/template"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"net/url"
@ -34,18 +33,19 @@ import (
//commonly used mime-types
const (
applicationJson = "application/json"
applicationXml = "application/xml"
textXml = "text/xml"
applicationJSON = "application/json"
applicationXML = "application/xml"
textXML = "text/xml"
)
var (
// custom error when user stop request handler manually.
USERSTOPRUN = errors.New("User stop run")
GlobalControllerRouter map[string][]ControllerComments = make(map[string][]ControllerComments) //pkgpath+controller:comments
// ErrAbort custom error when user stop request handler manually.
ErrAbort = errors.New("User stop run")
// GlobalControllerRouter store comments with controller. pkgpath+controller:comments
GlobalControllerRouter = make(map[string][]ControllerComments)
)
// store the comment for the controller method
// ControllerComments store the comment for the controller method
type ControllerComments struct {
Method string
Router string
@ -56,22 +56,31 @@ type ControllerComments struct {
// Controller defines some basic http request handler operations, such as
// http context, template and view, session and xsrf.
type Controller struct {
// context data
Ctx *context.Context
Data map[interface{}]interface{}
// route controller info
controllerName string
actionName string
TplNames string
methodMapping map[string]func() //method:routertree
gotofunc string
AppController interface{}
// template data
TplName string
Layout string
LayoutSections map[string]string // the key is the section name and the value is the template name
TplExt string
_xsrf_token string
gotofunc string
CruSession session.SessionStore
XSRFExpire int
AppController interface{}
EnableRender bool
// xsrf data
_xsrfToken string
XSRFExpire int
EnableXSRF bool
methodMapping map[string]func() //method:routertree
// session
CruSession session.Store
}
// ControllerInterface is an interface to uniform all controller handler.
@ -87,8 +96,8 @@ type ControllerInterface interface {
Options()
Finish()
Render() error
XsrfToken() string
CheckXsrfCookie() bool
XSRFToken() string
CheckXSRFCookie() bool
HandlerFunc(fn string) bool
URLMapping()
}
@ -96,7 +105,7 @@ type ControllerInterface interface {
// Init generates default values of controller operations.
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
c.Layout = ""
c.TplNames = ""
c.TplName = ""
c.controllerName = controllerName
c.actionName = actionName
c.Ctx = ctx
@ -104,19 +113,15 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin
c.AppController = app
c.EnableRender = true
c.EnableXSRF = true
c.Data = ctx.Input.Data
c.Data = ctx.Input.Data()
c.methodMapping = make(map[string]func())
}
// Prepare runs after Init before request function execution.
func (c *Controller) Prepare() {
}
func (c *Controller) Prepare() {}
// Finish runs after request function execution.
func (c *Controller) Finish() {
}
func (c *Controller) Finish() {}
// Get adds a request function to handle GET request.
func (c *Controller) Get() {
@ -153,20 +158,19 @@ func (c *Controller) Options() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405)
}
// call function fn
// HandlerFunc call function with the name
func (c *Controller) HandlerFunc(fnname string) bool {
if v, ok := c.methodMapping[fnname]; ok {
v()
return true
} else {
return false
}
return false
}
// URLMapping register the internal Controller router.
func (c *Controller) URLMapping() {
}
func (c *Controller) URLMapping() {}
// Mapping the method to function
func (c *Controller) Mapping(method string, fn func()) {
c.methodMapping[method] = fn
}
@ -177,13 +181,11 @@ func (c *Controller) Render() error {
return nil
}
rb, err := c.RenderBytes()
if err != nil {
return err
} else {
}
c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
c.Ctx.Output.Body(rb)
}
return nil
}
@ -196,24 +198,33 @@ func (c *Controller) RenderString() (string, error) {
// RenderBytes returns the bytes of rendered template string. Do not send out response.
func (c *Controller) RenderBytes() ([]byte, error) {
//if the controller has set layout, then first get the tplname's content set the content to the layout
var buf bytes.Buffer
if c.Layout != "" {
if c.TplNames == "" {
c.TplNames = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
if c.TplName == "" {
c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
}
if RunMode == "dev" {
BuildTemplate(ViewsPath)
if BConfig.RunMode == DEV {
buildFiles := []string{c.TplName}
if c.LayoutSections != nil {
for _, sectionTpl := range c.LayoutSections {
if sectionTpl == "" {
continue
}
newbytes := bytes.NewBufferString("")
if _, ok := BeeTemplates[c.TplNames]; !ok {
panic("can't find templatefile in the path:" + c.TplNames)
buildFiles = append(buildFiles, sectionTpl)
}
err := BeeTemplates[c.TplNames].ExecuteTemplate(newbytes, c.TplNames, c.Data)
}
BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...)
}
if _, ok := BeeTemplates[c.TplName]; !ok {
panic("can't find templatefile in the path:" + c.TplName)
}
err := BeeTemplates[c.TplName].ExecuteTemplate(&buf, c.TplName, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
tplcontent, _ := ioutil.ReadAll(newbytes)
c.Data["LayoutContent"] = template.HTML(string(tplcontent))
c.Data["LayoutContent"] = template.HTML(buf.String())
if c.LayoutSections != nil {
for sectionName, sectionTpl := range c.LayoutSections {
@ -222,44 +233,41 @@ func (c *Controller) RenderBytes() ([]byte, error) {
continue
}
sectionBytes := bytes.NewBufferString("")
err = BeeTemplates[sectionTpl].ExecuteTemplate(sectionBytes, sectionTpl, c.Data)
buf.Reset()
err = BeeTemplates[sectionTpl].ExecuteTemplate(&buf, sectionTpl, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
sectionContent, _ := ioutil.ReadAll(sectionBytes)
c.Data[sectionName] = template.HTML(string(sectionContent))
c.Data[sectionName] = template.HTML(buf.String())
}
}
ibytes := bytes.NewBufferString("")
err = BeeTemplates[c.Layout].ExecuteTemplate(ibytes, c.Layout, c.Data)
buf.Reset()
err = BeeTemplates[c.Layout].ExecuteTemplate(&buf, c.Layout, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
icontent, _ := ioutil.ReadAll(ibytes)
return icontent, nil
} else {
if c.TplNames == "" {
c.TplNames = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
return buf.Bytes(), nil
}
if RunMode == "dev" {
BuildTemplate(ViewsPath)
if c.TplName == "" {
c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
}
ibytes := bytes.NewBufferString("")
if _, ok := BeeTemplates[c.TplNames]; !ok {
panic("can't find templatefile in the path:" + c.TplNames)
if BConfig.RunMode == DEV {
BuildTemplate(BConfig.WebConfig.ViewsPath, c.TplName)
}
err := BeeTemplates[c.TplNames].ExecuteTemplate(ibytes, c.TplNames, c.Data)
if _, ok := BeeTemplates[c.TplName]; !ok {
panic("can't find templatefile in the path:" + c.TplName)
}
buf.Reset()
err := BeeTemplates[c.TplName].ExecuteTemplate(&buf, c.TplName, c.Data)
if err != nil {
Trace("template Execute err:", err)
return nil, err
}
icontent, _ := ioutil.ReadAll(ibytes)
return icontent, nil
}
return buf.Bytes(), nil
}
// Redirect sends the redirection response to url with status code.
@ -267,7 +275,7 @@ func (c *Controller) Redirect(url string, code int) {
c.Ctx.Redirect(code, url)
}
// Aborts stops controller handler and show the error data if code is defined in ErrorMap or code string.
// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string.
func (c *Controller) Abort(code string) {
status, err := strconv.Atoi(code)
if err != nil {
@ -285,74 +293,69 @@ func (c *Controller) CustomAbort(status int, body string) {
}
// last panic user string
c.Ctx.ResponseWriter.Write([]byte(body))
panic(USERSTOPRUN)
panic(ErrAbort)
}
// StopRun makes panic of USERSTOPRUN error and go to recover function if defined.
func (c *Controller) StopRun() {
panic(USERSTOPRUN)
panic(ErrAbort)
}
// UrlFor does another controller handler in this request function.
// URLFor does another controller handler in this request function.
// it goes to this controller method if endpoint is not clear.
func (c *Controller) UrlFor(endpoint string, values ...interface{}) string {
if len(endpoint) <= 0 {
func (c *Controller) URLFor(endpoint string, values ...interface{}) string {
if len(endpoint) == 0 {
return ""
}
if endpoint[0] == '.' {
return UrlFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...)
} else {
return UrlFor(endpoint, values...)
return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...)
}
return URLFor(endpoint, values...)
}
// ServeJson sends a json response with encoding charset.
func (c *Controller) ServeJson(encoding ...bool) {
var hasIndent bool
var hasencoding bool
if RunMode == "prod" {
hasIndent = false
} else {
// ServeJSON sends a json response with encoding charset.
func (c *Controller) ServeJSON(encoding ...bool) {
var (
hasIndent = true
hasEncoding = false
)
if BConfig.RunMode == PROD {
hasIndent = false
}
if len(encoding) > 0 && encoding[0] == true {
hasencoding = true
hasEncoding = true
}
c.Ctx.Output.Json(c.Data["json"], hasIndent, hasencoding)
c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding)
}
// ServeJsonp sends a jsonp response.
func (c *Controller) ServeJsonp() {
var hasIndent bool
if RunMode == "prod" {
// ServeJSONP sends a jsonp response.
func (c *Controller) ServeJSONP() {
hasIndent := true
if BConfig.RunMode == PROD {
hasIndent = false
} else {
hasIndent = true
}
c.Ctx.Output.Jsonp(c.Data["jsonp"], hasIndent)
c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent)
}
// ServeXml sends xml response.
func (c *Controller) ServeXml() {
var hasIndent bool
if RunMode == "prod" {
// ServeXML sends xml response.
func (c *Controller) ServeXML() {
hasIndent := true
if BConfig.RunMode == PROD {
hasIndent = false
} else {
hasIndent = true
}
c.Ctx.Output.Xml(c.Data["xml"], hasIndent)
c.Ctx.Output.XML(c.Data["xml"], hasIndent)
}
// ServeFormatted serve Xml OR Json, depending on the value of the Accept header
func (c *Controller) ServeFormatted() {
accept := c.Ctx.Input.Header("Accept")
switch accept {
case applicationJson:
c.ServeJson()
case applicationXml, textXml:
c.ServeXml()
case applicationJSON:
c.ServeJSON()
case applicationXML, textXML:
c.ServeXML()
default:
c.ServeJson()
c.ServeJSON()
}
}
@ -371,16 +374,13 @@ func (c *Controller) ParseForm(obj interface{}) error {
// GetString returns the input value by key string or the default value while it's present and input is blank
func (c *Controller) GetString(key string, def ...string) string {
var defv string
if len(def) > 0 {
defv = def[0]
}
if v := c.Ctx.Input.Query(key); v != "" {
return v
} else {
return defv
}
if len(def) > 0 {
return def[0]
}
return ""
}
// GetStrings returns the input string slice by key string or the default value while it's present and input is blank
@ -391,122 +391,79 @@ func (c *Controller) GetStrings(key string, def ...[]string) []string {
defv = def[0]
}
f := c.Input()
if f == nil {
if f := c.Input(); f == nil {
return defv
} else if vs := f[key]; len(vs) > 0 {
return vs
}
vs := f[key]
if len(vs) > 0 {
return vs
} else {
return defv
}
}
// GetInt returns input as an int or the default value while it's present and input is blank
func (c *Controller) GetInt(key string, def ...int) (int, error) {
var defv int
if len(def) > 0 {
defv = def[0]
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
if strv := c.Ctx.Input.Query(key); strv != "" {
return strconv.Atoi(strv)
} else {
return defv, nil
}
}
// GetInt8 return input as an int8 or the default value while it's present and input is blank
func (c *Controller) GetInt8(key string, def ...int8) (int8, error) {
var defv int8
if len(def) > 0 {
defv = def[0]
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
if strv := c.Ctx.Input.Query(key); strv != "" {
i64, err := strconv.ParseInt(strv, 10, 8)
i8 := int8(i64)
return i8, err
} else {
return defv, nil
}
return int8(i64), err
}
// GetInt16 returns input as an int16 or the default value while it's present and input is blank
func (c *Controller) GetInt16(key string, def ...int16) (int16, error) {
var defv int16
if len(def) > 0 {
defv = def[0]
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
if strv := c.Ctx.Input.Query(key); strv != "" {
i64, err := strconv.ParseInt(strv, 10, 16)
i16 := int16(i64)
return i16, err
} else {
return defv, nil
}
return int16(i64), err
}
// GetInt32 returns input as an int32 or the default value while it's present and input is blank
func (c *Controller) GetInt32(key string, def ...int32) (int32, error) {
var defv int32
if len(def) > 0 {
defv = def[0]
}
if strv := c.Ctx.Input.Query(key); strv != "" {
i64, err := strconv.ParseInt(c.Ctx.Input.Query(key), 10, 32)
i32 := int32(i64)
return i32, err
} else {
return defv, nil
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
i64, err := strconv.ParseInt(strv, 10, 32)
return int32(i64), err
}
// GetInt64 returns input value as int64 or the default value while it's present and input is blank.
func (c *Controller) GetInt64(key string, def ...int64) (int64, error) {
var defv int64
if len(def) > 0 {
defv = def[0]
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
if strv := c.Ctx.Input.Query(key); strv != "" {
return strconv.ParseInt(strv, 10, 64)
} else {
return defv, nil
}
}
// GetBool returns input value as bool or the default value while it's present and input is blank.
func (c *Controller) GetBool(key string, def ...bool) (bool, error) {
var defv bool
if len(def) > 0 {
defv = def[0]
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
if strv := c.Ctx.Input.Query(key); strv != "" {
return strconv.ParseBool(strv)
} else {
return defv, nil
}
}
// GetFloat returns input value as float64 or the default value while it's present and input is blank.
func (c *Controller) GetFloat(key string, def ...float64) (float64, error) {
var defv float64
if len(def) > 0 {
defv = def[0]
}
if strv := c.Ctx.Input.Query(key); strv != "" {
return strconv.ParseFloat(c.Ctx.Input.Query(key), 64)
} else {
return defv, nil
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
return strconv.ParseFloat(strv, 64)
}
// GetFile returns the file data in file upload field named as key.
@ -515,6 +472,40 @@ func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader,
return c.Ctx.Request.FormFile(key)
}
// GetFiles return multi-upload files
// files, err:=c.Getfiles("myfiles")
// if err != nil {
// http.Error(w, err.Error(), http.StatusNoContent)
// return
// }
// for i, _ := range files {
// //for each fileheader, get a handle to the actual file
// file, err := files[i].Open()
// defer file.Close()
// if err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
// //create destination file making sure the path is writeable.
// dst, err := os.Create("upload/" + files[i].Filename)
// defer dst.Close()
// if err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
// //copy the uploaded file to the destination file
// if _, err := io.Copy(dst, file); err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
// }
func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) {
if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok {
return files, nil
}
return nil, http.ErrMissingFile
}
// SaveToFile saves uploaded file to new path.
// it only operates the first one of mutil-upload form file field.
func (c *Controller) SaveToFile(fromfile, tofile string) error {
@ -533,7 +524,7 @@ func (c *Controller) SaveToFile(fromfile, tofile string) error {
}
// StartSession starts session and load old session data info this controller.
func (c *Controller) StartSession() session.SessionStore {
func (c *Controller) StartSession() session.Store {
if c.CruSession == nil {
c.CruSession = c.Ctx.Input.CruSession
}
@ -556,7 +547,7 @@ func (c *Controller) GetSession(name interface{}) interface{} {
return c.CruSession.Get(name)
}
// SetSession removes value from session.
// DelSession removes value from session.
func (c *Controller) DelSession(name interface{}) {
if c.CruSession == nil {
c.StartSession()
@ -570,13 +561,14 @@ func (c *Controller) SessionRegenerateID() {
if c.CruSession != nil {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
}
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession
}
// DestroySession cleans session data and session cookie.
func (c *Controller) DestroySession() {
c.Ctx.Input.CruSession.Flush()
c.Ctx.Input.CruSession = nil
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
}
@ -595,37 +587,35 @@ func (c *Controller) SetSecureCookie(Secret, name, value string, others ...inter
c.Ctx.SetSecureCookie(Secret, name, value, others...)
}
// XsrfToken creates a xsrf token string and returns.
func (c *Controller) XsrfToken() string {
if c._xsrf_token == "" {
var expire int64
// XSRFToken creates a CSRF token string and returns.
func (c *Controller) XSRFToken() string {
if c._xsrfToken == "" {
expire := int64(BConfig.WebConfig.XSRFExpire)
if c.XSRFExpire > 0 {
expire = int64(c.XSRFExpire)
} else {
expire = int64(XSRFExpire)
}
c._xsrf_token = c.Ctx.XsrfToken(XSRFKEY, expire)
c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire)
}
return c._xsrf_token
return c._xsrfToken
}
// CheckXsrfCookie checks xsrf token in this request is valid or not.
// CheckXSRFCookie checks xsrf token in this request is valid or not.
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
// or in form field value named as "_xsrf".
func (c *Controller) CheckXsrfCookie() bool {
func (c *Controller) CheckXSRFCookie() bool {
if !c.EnableXSRF {
return true
}
return c.Ctx.CheckXsrfCookie()
return c.Ctx.CheckXSRFCookie()
}
// XsrfFormHtml writes an input field contains xsrf token value.
func (c *Controller) XsrfFormHtml() string {
return "<input type=\"hidden\" name=\"_xsrf\" value=\"" +
c._xsrf_token + "\"/>"
// XSRFFormHTML writes an input field contains xsrf token value.
func (c *Controller) XSRFFormHTML() string {
return `<input type="hidden" name="_xsrf" value="` +
c.XSRFToken() + `" />`
}
// GetControllerAndAction gets the executing controller name and action name.
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
func (c *Controller) GetControllerAndAction() (string, string) {
return c.controllerName, c.actionName
}

View File

@ -15,61 +15,63 @@
package beego
import (
"fmt"
"testing"
"github.com/astaxie/beego/context"
)
func ExampleGetInt() {
i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
func TestGetInt(t *testing.T) {
i := context.NewInput()
i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt("age")
fmt.Printf("%T", val)
//Output: int
if val != 40 {
t.Errorf("TestGetInt expect 40,get %T,%v", val, val)
}
}
func ExampleGetInt8() {
i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
func TestGetInt8(t *testing.T) {
i := context.NewInput()
i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt8("age")
fmt.Printf("%T", val)
if val != 40 {
t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val)
}
//Output: int8
}
func ExampleGetInt16() {
i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
func TestGetInt16(t *testing.T) {
i := context.NewInput()
i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt16("age")
fmt.Printf("%T", val)
//Output: int16
if val != 40 {
t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val)
}
}
func ExampleGetInt32() {
i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
func TestGetInt32(t *testing.T) {
i := context.NewInput()
i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt32("age")
fmt.Printf("%T", val)
//Output: int32
if val != 40 {
t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val)
}
}
func ExampleGetInt64() {
i := &context.BeegoInput{Params: map[string]string{"age": "40"}}
func TestGetInt64(t *testing.T) {
i := context.NewInput()
i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt64("age")
fmt.Printf("%T", val)
//Output: int64
if val != 40 {
t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val)
}
}

17
doc.go Normal file
View File

@ -0,0 +1,17 @@
/*
Package beego provide a MVC framework
beego: an open-source, high-performance, modular, full-stack web framework
It is used for rapid development of RESTful APIs, web apps and backend services in Go.
beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding.
package main
import "github.com/astaxie/beego"
func main() {
beego.Run()
}
more information: http://beego.me
*/
package beego

23
docs.go
View File

@ -15,37 +15,24 @@
package beego
import (
"encoding/json"
"github.com/astaxie/beego/context"
)
var GlobalDocApi map[string]interface{}
func init() {
if EnableDocs {
GlobalDocApi = make(map[string]interface{})
}
}
// GlobalDocAPI store the swagger api documents
var GlobalDocAPI = make(map[string]interface{})
func serverDocs(ctx *context.Context) {
var obj interface{}
if splat := ctx.Input.Param(":splat"); splat == "" {
obj = GlobalDocApi["Root"]
obj = GlobalDocAPI["Root"]
} else {
if v, ok := GlobalDocApi[splat]; ok {
if v, ok := GlobalDocAPI[splat]; ok {
obj = v
}
}
if obj != nil {
bt, err := json.Marshal(obj)
if err != nil {
ctx.Output.SetStatus(504)
return
}
ctx.Output.Header("Content-Type", "application/json;charset=UTF-8")
ctx.Output.Header("Access-Control-Allow-Origin", "*")
ctx.Output.Body(bt)
ctx.Output.JSON(obj, false, false)
return
}
ctx.Output.SetStatus(404)

211
error.go
View File

@ -82,17 +82,18 @@ var tpl = `
`
// render default application error page with error and stack string.
func showErr(err interface{}, ctx *context.Context, Stack string) {
func showErr(err interface{}, ctx *context.Context, stack string) {
t, _ := template.New("beegoerrortemp").Parse(tpl)
data := make(map[string]string)
data["AppError"] = AppName + ":" + fmt.Sprint(err)
data["RequestMethod"] = ctx.Input.Method()
data["RequestURL"] = ctx.Input.Uri()
data["RemoteAddr"] = ctx.Input.IP()
data["Stack"] = Stack
data["BeegoVersion"] = VERSION
data["GoVersion"] = runtime.Version()
ctx.Output.SetStatus(500)
data := map[string]string{
"AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err),
"RequestMethod": ctx.Input.Method(),
"RequestURL": ctx.Input.URI(),
"RemoteAddr": ctx.Input.IP(),
"Stack": stack,
"BeegoVersion": VERSION,
"GoVersion": runtime.Version(),
}
ctx.ResponseWriter.WriteHeader(500)
t.Execute(ctx.ResponseWriter, data)
}
@ -203,48 +204,49 @@ type errorInfo struct {
errorType int
}
// map of http handlers for each error string.
var ErrorMaps map[string]*errorInfo
func init() {
ErrorMaps = make(map[string]*errorInfo)
}
// ErrorMaps holds map of http handlers for each error string.
// there is 10 kinds default error(40x and 50x)
var ErrorMaps = make(map[string]*errorInfo, 10)
// show 401 unauthorized error.
func unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Unauthorized"
data := map[string]interface{}{
"Title": http.StatusText(401),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested can't be authorized." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br>The credentials you supplied are incorrect" +
"<br>There are errors in the website address" +
"</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 402 Payment Required
func paymentRequired(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Payment Required"
data := map[string]interface{}{
"Title": http.StatusText(402),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested Payment Required." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br>The credentials you supplied are incorrect" +
"<br>There are errors in the website address" +
"</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 403 forbidden error.
func forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Forbidden"
data := map[string]interface{}{
"Title": http.StatusText(403),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is forbidden." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
@ -252,15 +254,16 @@ func forbidden(rw http.ResponseWriter, r *http.Request) {
"<br>The site may be disabled" +
"<br>You need to log in" +
"</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 404 notfound error.
func notFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Page Not Found"
data := map[string]interface{}{
"Title": http.StatusText(404),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested has flown the coop." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
@ -269,194 +272,163 @@ func notFound(rw http.ResponseWriter, r *http.Request) {
"<br>You were looking for your puppy and got lost" +
"<br>You like 404 pages" +
"</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 405 Method Not Allowed
func methodNotAllowed(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Method Not Allowed"
data := map[string]interface{}{
"Title": http.StatusText(405),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The method you have requested Not Allowed." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br>The method specified in the Request-Line is not allowed for the resource identified by the Request-URI" +
"<br>The response MUST include an Allow header containing a list of valid methods for the requested resource." +
"</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 500 internal server error.
func internalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Internal Server Error"
data := map[string]interface{}{
"Title": http.StatusText(500),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is down right now." +
"<br><br><ul>" +
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 501 Not Implemented.
func notImplemented(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Not Implemented"
data := map[string]interface{}{
"Title": http.StatusText(504),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is Not Implemented." +
"<br><br><ul>" +
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 502 Bad Gateway.
func badGateway(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Bad Gateway"
data := map[string]interface{}{
"Title": http.StatusText(502),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is down right now." +
"<br><br><ul>" +
"<br>The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request." +
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 503 service unavailable error.
func serviceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Service Unavailable"
data := map[string]interface{}{
"Title": http.StatusText(503),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br><br>The page is overloaded" +
"<br>Please try again later." +
"</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// show 504 Gateway Timeout.
func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Gateway Timeout"
data := map[string]interface{}{
"Title": http.StatusText(504),
"BeegoVersion": VERSION,
}
data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br><br>The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI." +
"<br>Please try again later." +
"</ul>")
data["BeegoVersion"] = VERSION
t.Execute(rw, data)
}
// register default error http handlers, 404,401,403,500 and 503.
func registerDefaultErrorHandler() {
if _, ok := ErrorMaps["401"]; !ok {
Errorhandler("401", unauthorized)
}
if _, ok := ErrorMaps["402"]; !ok {
Errorhandler("402", paymentRequired)
}
if _, ok := ErrorMaps["403"]; !ok {
Errorhandler("403", forbidden)
}
if _, ok := ErrorMaps["404"]; !ok {
Errorhandler("404", notFound)
}
if _, ok := ErrorMaps["405"]; !ok {
Errorhandler("405", methodNotAllowed)
}
if _, ok := ErrorMaps["500"]; !ok {
Errorhandler("500", internalServerError)
}
if _, ok := ErrorMaps["501"]; !ok {
Errorhandler("501", notImplemented)
}
if _, ok := ErrorMaps["502"]; !ok {
Errorhandler("502", badGateway)
}
if _, ok := ErrorMaps["503"]; !ok {
Errorhandler("503", serviceUnavailable)
}
if _, ok := ErrorMaps["504"]; !ok {
Errorhandler("504", gatewayTimeout)
}
}
// ErrorHandler registers http.HandlerFunc to each http err code string.
// usage:
// beego.ErrorHandler("404",NotFound)
// beego.ErrorHandler("500",InternalServerError)
func Errorhandler(code string, h http.HandlerFunc) *App {
errinfo := &errorInfo{}
errinfo.errorType = errorTypeHandler
errinfo.handler = h
errinfo.method = code
ErrorMaps[code] = errinfo
func ErrorHandler(code string, h http.HandlerFunc) *App {
ErrorMaps[code] = &errorInfo{
errorType: errorTypeHandler,
handler: h,
method: code,
}
return BeeApp
}
// ErrorController registers ControllerInterface to each http err code string.
// usage:
// beego.ErrorHandler(&controllers.ErrorController{})
// beego.ErrorController(&controllers.ErrorController{})
func ErrorController(c ControllerInterface) *App {
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) && strings.HasPrefix(rt.Method(i).Name, "Error") {
errinfo := &errorInfo{}
errinfo.errorType = errorTypeController
errinfo.controllerType = ct
errinfo.method = rt.Method(i).Name
errname := strings.TrimPrefix(rt.Method(i).Name, "Error")
ErrorMaps[errname] = errinfo
methodName := rt.Method(i).Name
if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") {
errName := strings.TrimPrefix(methodName, "Error")
ErrorMaps[errName] = &errorInfo{
errorType: errorTypeController,
controllerType: ct,
method: methodName,
}
}
}
return BeeApp
}
// show error string as simple text message.
// if error string is empty, show 500 error as default.
func exception(errcode string, ctx *context.Context) {
code, err := strconv.Atoi(errcode)
if err != nil {
code = 503
// if error string is empty, show 503 or 500 error as default.
func exception(errCode string, ctx *context.Context) {
atoi := func(code string) int {
v, err := strconv.Atoi(code)
if err == nil {
return v
}
ctx.ResponseWriter.WriteHeader(code)
if h, ok := ErrorMaps[errcode]; ok {
executeError(h, ctx)
return
} else if h, ok := ErrorMaps["503"]; ok {
executeError(h, ctx)
return
} else {
ctx.WriteString(errcode)
return 503
}
for _, ec := range []string{errCode, "503", "500"} {
if h, ok := ErrorMaps[ec]; ok {
executeError(h, ctx, atoi(ec))
return
}
}
//if 50x error has been removed from errorMap
ctx.ResponseWriter.WriteHeader(atoi(errCode))
ctx.WriteString(errCode)
}
func executeError(err *errorInfo, ctx *context.Context) {
func executeError(err *errorInfo, ctx *context.Context, code int) {
if err.errorType == errorTypeHandler {
err.handler(ctx.ResponseWriter, ctx.Request)
return
}
if err.errorType == errorTypeController {
ctx.Output.SetStatus(code)
//Invoke the request handler
vc := reflect.New(err.controllerType)
execController, ok := vc.Interface().(ControllerInterface)
@ -471,18 +443,15 @@ func executeError(err *errorInfo, ctx *context.Context) {
execController.URLMapping()
in := make([]reflect.Value, 0)
method := vc.MethodByName(err.method)
method.Call(in)
method.Call([]reflect.Value{})
//render template
if ctx.Output.Status == 0 {
if AutoRender {
if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil {
panic(err)
}
}
}
// finish all runrouter. release resource
execController.Finish()

View File

@ -1,5 +0,0 @@
appname = beeapi
httpport = 8080
runmode = dev
autorender = false
copyrequestbody = true

View File

@ -1,63 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package controllers
import (
"encoding/json"
"github.com/astaxie/beego"
"github.com/astaxie/beego/example/beeapi/models"
)
type ObjectController struct {
beego.Controller
}
func (o *ObjectController) Post() {
var ob models.Object
json.Unmarshal(o.Ctx.Input.RequestBody, &ob)
objectid := models.AddOne(ob)
o.Data["json"] = map[string]string{"ObjectId": objectid}
o.ServeJson()
}
func (o *ObjectController) Get() {
objectId := o.Ctx.Input.Params[":objectId"]
if objectId != "" {
ob, err := models.GetOne(objectId)
if err != nil {
o.Data["json"] = err
} else {
o.Data["json"] = ob
}
} else {
obs := models.GetAll()
o.Data["json"] = obs
}
o.ServeJson()
}
func (o *ObjectController) Put() {
objectId := o.Ctx.Input.Params[":objectId"]
var ob models.Object
json.Unmarshal(o.Ctx.Input.RequestBody, &ob)
err := models.Update(objectId, ob.Score)
if err != nil {
o.Data["json"] = err
} else {
o.Data["json"] = "update success!"
}
o.ServeJson()
}
func (o *ObjectController) Delete() {
objectId := o.Ctx.Input.Params[":objectId"]
models.Delete(objectId)
o.Data["json"] = "delete success!"
o.ServeJson()
}

View File

@ -1,30 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package main
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/example/beeapi/controllers"
)
// Objects
// URL HTTP Verb Functionality
// /object POST Creating Objects
// /object/<objectId> GET Retrieving Objects
// /object/<objectId> PUT Updating Objects
// /object GET Queries
// /object/<objectId> DELETE Deleting Objects
func main() {
beego.RESTRouter("/object", &controllers.ObjectController{})
beego.Run()
}

View File

@ -1,58 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie
package models
import (
"errors"
"strconv"
"time"
)
var (
Objects map[string]*Object
)
type Object struct {
ObjectId string
Score int64
PlayerName string
}
func init() {
Objects = make(map[string]*Object)
Objects["hjkhsbnmn123"] = &Object{"hjkhsbnmn123", 100, "astaxie"}
Objects["mjjkxsxsaa23"] = &Object{"mjjkxsxsaa23", 101, "someone"}
}
func AddOne(object Object) (ObjectId string) {
object.ObjectId = "astaxie" + strconv.FormatInt(time.Now().UnixNano(), 10)
Objects[object.ObjectId] = &object
return object.ObjectId
}
func GetOne(ObjectId string) (object *Object, err error) {
if v, ok := Objects[ObjectId]; ok {
return v, nil
}
return nil, errors.New("ObjectId Not Exist")
}
func GetAll() map[string]*Object {
return Objects
}
func Update(ObjectId string, Score int64) (err error) {
if v, ok := Objects[ObjectId]; ok {
v.Score = Score
return nil
}
return errors.New("ObjectId Not Exist")
}
func Delete(ObjectId string) {
delete(Objects, ObjectId)
}

View File

@ -1,3 +0,0 @@
appname = chat
httpport = 8080
runmode = dev

View File

@ -1,20 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors Unknwon
package controllers
import (
"github.com/astaxie/beego"
)
type MainController struct {
beego.Controller
}
func (m *MainController) Get() {
m.Data["host"] = m.Ctx.Request.Host
m.TplNames = "index.tpl"
}

View File

@ -1,181 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors Unknwon
package controllers
import (
"io/ioutil"
"math/rand"
"net/http"
"time"
"github.com/astaxie/beego"
"github.com/gorilla/websocket"
)
const (
// Time allowed to write a message to the client.
writeWait = 10 * time.Second
// Time allowed to read the next message from the client.
readWait = 60 * time.Second
// Send pings to client with this period. Must be less than readWait.
pingPeriod = (readWait * 9) / 10
// Maximum message size allowed from client.
maxMessageSize = 512
)
func init() {
rand.Seed(time.Now().UTC().UnixNano())
go h.run()
}
// connection is an middleman between the websocket connection and the hub.
type connection struct {
username string
// The websocket connection.
ws *websocket.Conn
// Buffered channel of outbound messages.
send chan []byte
}
// readPump pumps messages from the websocket connection to the hub.
func (c *connection) readPump() {
defer func() {
h.unregister <- c
c.ws.Close()
}()
c.ws.SetReadLimit(maxMessageSize)
c.ws.SetReadDeadline(time.Now().Add(readWait))
for {
op, r, err := c.ws.NextReader()
if err != nil {
break
}
switch op {
case websocket.PongMessage:
c.ws.SetReadDeadline(time.Now().Add(readWait))
case websocket.TextMessage:
message, err := ioutil.ReadAll(r)
if err != nil {
break
}
h.broadcast <- []byte(c.username + "_" + time.Now().Format("15:04:05") + ":" + string(message))
}
}
}
// write writes a message with the given opCode and payload.
func (c *connection) write(opCode int, payload []byte) error {
c.ws.SetWriteDeadline(time.Now().Add(writeWait))
return c.ws.WriteMessage(opCode, payload)
}
// writePump pumps messages from the hub to the websocket connection.
func (c *connection) writePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.ws.Close()
}()
for {
select {
case message, ok := <-c.send:
if !ok {
c.write(websocket.CloseMessage, []byte{})
return
}
if err := c.write(websocket.TextMessage, message); err != nil {
return
}
case <-ticker.C:
if err := c.write(websocket.PingMessage, []byte{}); err != nil {
return
}
}
}
}
type hub struct {
// Registered connections.
connections map[*connection]bool
// Inbound messages from the connections.
broadcast chan []byte
// Register requests from the connections.
register chan *connection
// Unregister requests from connections.
unregister chan *connection
}
var h = &hub{
broadcast: make(chan []byte, maxMessageSize),
register: make(chan *connection, 1),
unregister: make(chan *connection, 1),
connections: make(map[*connection]bool),
}
func (h *hub) run() {
for {
select {
case c := <-h.register:
h.connections[c] = true
case c := <-h.unregister:
delete(h.connections, c)
close(c.send)
case m := <-h.broadcast:
for c := range h.connections {
select {
case c.send <- m:
default:
close(c.send)
delete(h.connections, c)
}
}
}
}
}
type WSController struct {
beego.Controller
}
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
func (w *WSController) Get() {
ws, err := upgrader.Upgrade(w.Ctx.ResponseWriter, w.Ctx.Request, nil)
if _, ok := err.(websocket.HandshakeError); ok {
http.Error(w.Ctx.ResponseWriter, "Not a websocket handshake", 400)
return
} else if err != nil {
return
}
c := &connection{send: make(chan []byte, 256), ws: ws, username: randomString(10)}
h.register <- c
go c.writePump()
c.readPump()
}
func randomString(l int) string {
bytes := make([]byte, l)
for i := 0; i < l; i++ {
bytes[i] = byte(randInt(65, 90))
}
return string(bytes)
}
func randInt(min int, max int) int {
return min + rand.Intn(max-min)
}

View File

@ -1,17 +0,0 @@
// Beego (http://beego.me/)
// @description beego is an open-source, high-performance web framework for the Go programming language.
// @link http://github.com/astaxie/beego for the canonical source repository
// @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors Unknwon
package main
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/example/chat/controllers"
)
func main() {
beego.Router("/", &controllers.MainController{})
beego.Router("/ws", &controllers.WSController{})
beego.Run()
}

View File

@ -1,92 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>Chat Example</title>
<script src="//code.jquery.com/jquery-2.1.3.min.js"></script>
<script type="text/javascript">
$(function() {
var conn;
var msg = $("#msg");
var log = $("#log");
function appendLog(msg) {
var d = log[0]
var doScroll = d.scrollTop == d.scrollHeight - d.clientHeight;
msg.appendTo(log)
if (doScroll) {
d.scrollTop = d.scrollHeight - d.clientHeight;
}
}
$("#form").submit(function() {
if (!conn) {
return false;
}
if (!msg.val()) {
return false;
}
conn.send(msg.val());
msg.val("");
return false
});
if (window["WebSocket"]) {
conn = new WebSocket("ws://{{.host}}/ws");
conn.onclose = function(evt) {
appendLog($("<div><b>Connection closed.</b></div>"))
}
conn.onmessage = function(evt) {
appendLog($("<div/>").text(evt.data))
}
} else {
appendLog($("<div><b>Your browser does not support WebSockets.</b></div>"))
}
});
</script>
<style type="text/css">
html {
overflow: hidden;
}
body {
overflow: hidden;
padding: 0;
margin: 0;
width: 100%;
height: 100%;
background: gray;
}
#log {
background: white;
margin: 0;
padding: 0.5em 0.5em 0.5em 0.5em;
position: absolute;
top: 0.5em;
left: 0.5em;
right: 0.5em;
bottom: 3em;
overflow: auto;
}
#form {
padding: 0 0.5em 0 0.5em;
margin: 0;
position: absolute;
bottom: 1em;
left: 0px;
width: 100%;
overflow: hidden;
}
</style>
</head>
<body>
<div id="log"></div>
<form id="form">
<input type="submit" value="Send" />
<input type="text" id="msg" size="64"/>
</form>
</body>
</html>

View File

@ -16,11 +16,12 @@ package beego
import "github.com/astaxie/beego/context"
// FilterFunc defines filter function type.
// FilterFunc defines a filter function which is invoked before the controller handler is executed.
type FilterFunc func(*context.Context)
// FilterRouter defines filter operation before controller handler execution.
// it can match patterned url and do filter function when action arrives.
// FilterRouter defines a filter operation which is invoked before the controller handler is executed.
// It can match the URL against a pattern, and execute a filter function
// when a request with a matching URL arrives.
type FilterRouter struct {
filterFunc FilterFunc
tree *Tree
@ -28,16 +29,15 @@ type FilterRouter struct {
returnOnOutput bool
}
// ValidRouter check current request is valid for this filter.
// if matched, returns parsed params in this request by defined filter router pattern.
func (f *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
isok, params := f.tree.Match(router)
if isok == nil {
return false, nil
// ValidRouter checks if the current request is matched by this filter.
// If the request is matched, the values of the URL parameters defined
// by the filter pattern are also returned.
func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool {
isOk := f.tree.Match(url, ctx)
if isOk != nil {
if b, ok := isOk.(bool); ok {
return b
}
if isok, ok := isok.(bool); ok {
return isok, params
} else {
return false, nil
}
return false
}

View File

@ -20,10 +20,16 @@ import (
"testing"
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
)
func init() {
BeeLogger = logs.NewLogger(10000)
BeeLogger.SetLogger("console", "")
}
var FilterUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"]))
ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first")))
}
func TestFilter(t *testing.T) {

View File

@ -83,27 +83,27 @@ func (fd *FlashData) Store(c *Controller) {
c.Data["flash"] = fd.Data
var flashValue string
for key, value := range fd.Data {
flashValue += "\x00" + key + "\x23" + FlashSeperator + "\x23" + value + "\x00"
flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00"
}
c.Ctx.SetCookie(FlashName, url.QueryEscape(flashValue), 0, "/")
c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/")
}
// ReadFromRequest parsed flash data from encoded values in cookie.
func ReadFromRequest(c *Controller) *FlashData {
flash := NewFlash()
if cookie, err := c.Ctx.Request.Cookie(FlashName); err == nil {
if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil {
v, _ := url.QueryUnescape(cookie.Value)
vals := strings.Split(v, "\x00")
for _, v := range vals {
if len(v) > 0 {
kv := strings.Split(v, "\x23"+FlashSeperator+"\x23")
kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23")
if len(kv) == 2 {
flash.Data[kv[0]] = kv[1]
}
}
}
//read one time then delete it
c.Ctx.SetCookie(FlashName, "", -1, "/")
c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/")
}
c.Data["flash"] = flash.Data
return flash

View File

@ -30,7 +30,7 @@ func (t *TestFlashController) TestWriteFlash() {
flash.Notice("TestFlashString")
flash.Store(&t.Controller)
// we choose to serve json because we don't want to load a template html file
t.ServeJson(true)
t.ServeJSON(true)
}
func TestFlashHeader(t *testing.T) {

28
grace/conn.go Normal file
View File

@ -0,0 +1,28 @@
package grace
import (
"errors"
"net"
)
type graceConn struct {
net.Conn
server *Server
}
func (c graceConn) Close() (err error) {
defer func() {
if r := recover(); r != nil {
switch x := r.(type) {
case string:
err = errors.New(x)
case error:
err = x
default:
err = errors.New("Unknown panic")
}
}
}()
c.server.wg.Done()
return c.Conn.Close()
}

158
grace/grace.go Normal file
View File

@ -0,0 +1,158 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package grace use to hot reload
// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/
//
// Usage:
//
// import(
// "log"
// "net/http"
// "os"
//
// "github.com/astaxie/beego/grace"
// )
//
// func handler(w http.ResponseWriter, r *http.Request) {
// w.Write([]byte("WORLD!"))
// }
//
// func main() {
// mux := http.NewServeMux()
// mux.HandleFunc("/hello", handler)
//
// err := grace.ListenAndServe("localhost:8080", mux)
// if err != nil {
// log.Println(err)
// }
// log.Println("Server on 8080 stopped")
// os.Exit(0)
// }
package grace
import (
"flag"
"net/http"
"os"
"strings"
"sync"
"syscall"
"time"
)
const (
// PreSignal is the position to add filter before signal
PreSignal = iota
// PostSignal is the position to add filter after signal
PostSignal
// StateInit represent the application inited
StateInit
// StateRunning represent the application is running
StateRunning
// StateShuttingDown represent the application is shutting down
StateShuttingDown
// StateTerminate represent the application is killed
StateTerminate
)
var (
regLock *sync.Mutex
runningServers map[string]*Server
runningServersOrder []string
socketPtrOffsetMap map[string]uint
runningServersForked bool
// DefaultReadTimeOut is the HTTP read timeout
DefaultReadTimeOut time.Duration
// DefaultWriteTimeOut is the HTTP Write timeout
DefaultWriteTimeOut time.Duration
// DefaultMaxHeaderBytes is the Max HTTP Herder size, default is 0, no limit
DefaultMaxHeaderBytes int
// DefaultTimeout is the shutdown server's timeout. default is 60s
DefaultTimeout = 60 * time.Second
isChild bool
socketOrder string
once sync.Once
)
func onceInit() {
regLock = &sync.Mutex{}
flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
runningServers = make(map[string]*Server)
runningServersOrder = []string{}
socketPtrOffsetMap = make(map[string]uint)
}
// NewServer returns a new graceServer.
func NewServer(addr string, handler http.Handler) (srv *Server) {
once.Do(onceInit)
regLock.Lock()
defer regLock.Unlock()
if !flag.Parsed() {
flag.Parse()
}
if len(socketOrder) > 0 {
for i, addr := range strings.Split(socketOrder, ",") {
socketPtrOffsetMap[addr] = uint(i)
}
} else {
socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
}
srv = &Server{
wg: sync.WaitGroup{},
sigChan: make(chan os.Signal),
isChild: isChild,
SignalHooks: map[int]map[os.Signal][]func(){
PreSignal: {
syscall.SIGHUP: {},
syscall.SIGINT: {},
syscall.SIGTERM: {},
},
PostSignal: {
syscall.SIGHUP: {},
syscall.SIGINT: {},
syscall.SIGTERM: {},
},
},
state: StateInit,
Network: "tcp",
}
srv.Server = &http.Server{}
srv.Server.Addr = addr
srv.Server.ReadTimeout = DefaultReadTimeOut
srv.Server.WriteTimeout = DefaultWriteTimeOut
srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes
srv.Server.Handler = handler
runningServersOrder = append(runningServersOrder, addr)
runningServers[addr] = srv
return
}
// ListenAndServe refer http.ListenAndServe
func ListenAndServe(addr string, handler http.Handler) error {
server := NewServer(addr, handler)
return server.ListenAndServe()
}
// ListenAndServeTLS refer http.ListenAndServeTLS
func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
server := NewServer(addr, handler)
return server.ListenAndServeTLS(certFile, keyFile)
}

62
grace/listener.go Normal file
View File

@ -0,0 +1,62 @@
package grace
import (
"net"
"os"
"syscall"
"time"
)
type graceListener struct {
net.Listener
stop chan error
stopped bool
server *Server
}
func newGraceListener(l net.Listener, srv *Server) (el *graceListener) {
el = &graceListener{
Listener: l,
stop: make(chan error),
server: srv,
}
go func() {
_ = <-el.stop
el.stopped = true
el.stop <- el.Listener.Close()
}()
return
}
func (gl *graceListener) Accept() (c net.Conn, err error) {
tc, err := gl.Listener.(*net.TCPListener).AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
c = graceConn{
Conn: tc,
server: gl.server,
}
gl.server.wg.Add(1)
return
}
func (gl *graceListener) Close() error {
if gl.stopped {
return syscall.EINVAL
}
gl.stop <- nil
return <-gl.stop
}
func (gl *graceListener) File() *os.File {
// returns a dup(2) - FD_CLOEXEC flag *not* set
tl := gl.Listener.(*net.TCPListener)
fl, _ := tl.File()
return fl
}

293
grace/server.go Normal file
View File

@ -0,0 +1,293 @@
package grace
import (
"crypto/tls"
"fmt"
"log"
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"strings"
"sync"
"syscall"
"time"
)
// Server embedded http.Server
type Server struct {
*http.Server
GraceListener net.Listener
SignalHooks map[int]map[os.Signal][]func()
tlsInnerListener *graceListener
wg sync.WaitGroup
sigChan chan os.Signal
isChild bool
state uint8
Network string
}
// Serve accepts incoming connections on the Listener l,
// creating a new service goroutine for each.
// The service goroutines read requests and then call srv.Handler to reply to them.
func (srv *Server) Serve() (err error) {
srv.state = StateRunning
err = srv.Server.Serve(srv.GraceListener)
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
srv.wg.Wait()
srv.state = StateTerminate
return
}
// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
// to handle requests on incoming connections. If srv.Addr is blank, ":http" is
// used.
func (srv *Server) ListenAndServe() (err error) {
addr := srv.Addr
if addr == "" {
addr = ":http"
}
go srv.handleSignals()
l, err := srv.getListener(addr)
if err != nil {
log.Println(err)
return err
}
srv.GraceListener = newGraceListener(l, srv)
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Kill()
if err != nil {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
}
// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming TLS connections.
//
// Filenames containing a certificate and matching private key for the server must
// be provided. If the certificate is signed by a certificate authority, the
// certFile should be the concatenation of the server's certificate followed by the
// CA's certificate.
//
// If srv.Addr is blank, ":https" is used.
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
addr := srv.Addr
if addr == "" {
addr = ":https"
}
config := &tls.Config{}
if srv.TLSConfig != nil {
*config = *srv.TLSConfig
}
if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1"}
}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return
}
go srv.handleSignals()
l, err := srv.getListener(addr)
if err != nil {
log.Println(err)
return err
}
srv.tlsInnerListener = newGraceListener(l, srv)
srv.GraceListener = tls.NewListener(srv.tlsInnerListener, config)
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Kill()
if err != nil {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
}
// getListener either opens a new socket to listen on, or takes the acceptor socket
// it got passed when restarted.
func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
if srv.isChild {
var ptrOffset uint
if len(socketPtrOffsetMap) > 0 {
ptrOffset = socketPtrOffsetMap[laddr]
log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
}
f := os.NewFile(uintptr(3+ptrOffset), "")
l, err = net.FileListener(f)
if err != nil {
err = fmt.Errorf("net.FileListener error: %v", err)
return
}
} else {
l, err = net.Listen(srv.Network, laddr)
if err != nil {
err = fmt.Errorf("net.Listen error: %v", err)
return
}
}
return
}
// handleSignals listens for os Signals and calls any hooked in function that the
// user had registered with the signal.
func (srv *Server) handleSignals() {
var sig os.Signal
signal.Notify(
srv.sigChan,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
)
pid := syscall.Getpid()
for {
sig = <-srv.sigChan
srv.signalHooks(PreSignal, sig)
switch sig {
case syscall.SIGHUP:
log.Println(pid, "Received SIGHUP. forking.")
err := srv.fork()
if err != nil {
log.Println("Fork err:", err)
}
case syscall.SIGINT:
log.Println(pid, "Received SIGINT.")
srv.shutdown()
case syscall.SIGTERM:
log.Println(pid, "Received SIGTERM.")
srv.shutdown()
default:
log.Printf("Received %v: nothing i care about...\n", sig)
}
srv.signalHooks(PostSignal, sig)
}
}
func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
return
}
for _, f := range srv.SignalHooks[ppFlag][sig] {
f()
}
return
}
// shutdown closes the listener so that no new connections are accepted. it also
// starts a goroutine that will serverTimeout (stop all running requests) the server
// after DefaultTimeout.
func (srv *Server) shutdown() {
if srv.state != StateRunning {
return
}
srv.state = StateShuttingDown
if DefaultTimeout >= 0 {
go srv.serverTimeout(DefaultTimeout)
}
err := srv.GraceListener.Close()
if err != nil {
log.Println(syscall.Getpid(), "Listener.Close() error:", err)
} else {
log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
}
}
// serverTimeout forces the server to shutdown in a given timeout - whether it
// finished outstanding requests or not. if Read/WriteTimeout are not set or the
// max header size is very big a connection could hang
func (srv *Server) serverTimeout(d time.Duration) {
defer func() {
if r := recover(); r != nil {
log.Println("WaitGroup at 0", r)
}
}()
if srv.state != StateShuttingDown {
return
}
time.Sleep(d)
log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
for {
if srv.state == StateTerminate {
break
}
srv.wg.Done()
}
}
func (srv *Server) fork() (err error) {
regLock.Lock()
defer regLock.Unlock()
if runningServersForked {
return
}
runningServersForked = true
var files = make([]*os.File, len(runningServers))
var orderArgs = make([]string, len(runningServers))
for _, srvPtr := range runningServers {
switch srvPtr.GraceListener.(type) {
case *graceListener:
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
default:
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
}
orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
}
log.Println(files)
path := os.Args[0]
var args []string
if len(os.Args) > 1 {
for _, arg := range os.Args[1:] {
if arg == "-graceful" {
break
}
args = append(args, arg)
}
}
args = append(args, "-graceful")
if len(runningServers) > 1 {
args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
log.Println(args)
}
cmd := exec.Command(path, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.ExtraFiles = files
err = cmd.Start()
if err != nil {
log.Fatalf("Restart: Failed to launch, error: %v", err)
}
return
}

95
hooks.go Normal file
View File

@ -0,0 +1,95 @@
package beego
import (
"encoding/json"
"mime"
"net/http"
"path/filepath"
"github.com/astaxie/beego/session"
)
//
func registerMime() error {
for k, v := range mimemaps {
mime.AddExtensionType(k, v)
}
return nil
}
// register default error http handlers, 404,401,403,500 and 503.
func registerDefaultErrorHandler() error {
m := map[string]func(http.ResponseWriter, *http.Request){
"401": unauthorized,
"402": paymentRequired,
"403": forbidden,
"404": notFound,
"405": methodNotAllowed,
"500": internalServerError,
"501": notImplemented,
"502": badGateway,
"503": serviceUnavailable,
"504": gatewayTimeout,
}
for e, h := range m {
if _, ok := ErrorMaps[e]; !ok {
ErrorHandler(e, h)
}
}
return nil
}
func registerSession() error {
if BConfig.WebConfig.Session.SessionOn {
var err error
sessionConfig := AppConfig.String("sessionConfig")
if sessionConfig == "" {
conf := map[string]interface{}{
"cookieName": BConfig.WebConfig.Session.SessionName,
"gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime,
"providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig),
"secure": BConfig.Listen.EnableHTTPS,
"enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie,
"domain": BConfig.WebConfig.Session.SessionDomain,
"cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime,
}
confBytes, err := json.Marshal(conf)
if err != nil {
return err
}
sessionConfig = string(confBytes)
}
if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, sessionConfig); err != nil {
return err
}
go GlobalSessions.GC()
}
return nil
}
func registerTemplate() error {
if BConfig.WebConfig.AutoRender {
if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil {
if BConfig.RunMode == DEV {
Warn(err)
}
return err
}
}
return nil
}
func registerDocs() error {
if BConfig.WebConfig.EnableDocs {
Get("/docs", serverDocs)
Get("/docs/*", serverDocs)
}
return nil
}
func registerAdmin() error {
if BConfig.Listen.EnableAdmin {
go beeAdminApp.Run()
}
return nil
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Package httplib is used as http.Client
// Usage:
//
// import "github.com/astaxie/beego/httplib"
@ -32,6 +33,7 @@ package httplib
import (
"bytes"
"compress/gzip"
"crypto/tls"
"encoding/json"
"encoding/xml"
@ -50,7 +52,14 @@ import (
"time"
)
var defaultSetting = BeegoHttpSettings{false, "beegoServer", 60 * time.Second, 60 * time.Second, nil, nil, nil, false}
var defaultSetting = BeegoHTTPSettings{
UserAgent: "beegoServer",
ConnectTimeout: 60 * time.Second,
ReadWriteTimeout: 60 * time.Second,
Gzip: true,
DumpBody: true,
}
var defaultCookieJar http.CookieJar
var settingMutex sync.Mutex
@ -61,132 +70,163 @@ func createDefaultCookie() {
defaultCookieJar, _ = cookiejar.New(nil)
}
// Overwrite default settings
func SetDefaultSetting(setting BeegoHttpSettings) {
// SetDefaultSetting Overwrite default settings
func SetDefaultSetting(setting BeegoHTTPSettings) {
settingMutex.Lock()
defer settingMutex.Unlock()
defaultSetting = setting
if defaultSetting.ConnectTimeout == 0 {
defaultSetting.ConnectTimeout = 60 * time.Second
}
if defaultSetting.ReadWriteTimeout == 0 {
defaultSetting.ReadWriteTimeout = 60 * time.Second
}
}
// return *BeegoHttpRequest with specific method
func newBeegoRequest(url, method string) *BeegoHttpRequest {
// NewBeegoRequest return *BeegoHttpRequest with specific method
func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest {
var resp http.Response
u, err := url.Parse(rawurl)
if err != nil {
log.Println("Httplib:", err)
}
req := http.Request{
URL: u,
Method: method,
Header: make(http.Header),
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
}
return &BeegoHttpRequest{url, &req, map[string]string{}, map[string]string{}, defaultSetting, &resp, nil}
return &BeegoHTTPRequest{
url: rawurl,
req: &req,
params: map[string][]string{},
files: map[string]string{},
setting: defaultSetting,
resp: &resp,
}
}
// Get returns *BeegoHttpRequest with GET method.
func Get(url string) *BeegoHttpRequest {
return newBeegoRequest(url, "GET")
func Get(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "GET")
}
// Post returns *BeegoHttpRequest with POST method.
func Post(url string) *BeegoHttpRequest {
return newBeegoRequest(url, "POST")
func Post(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "POST")
}
// Put returns *BeegoHttpRequest with PUT method.
func Put(url string) *BeegoHttpRequest {
return newBeegoRequest(url, "PUT")
func Put(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "PUT")
}
// Delete returns *BeegoHttpRequest DELETE method.
func Delete(url string) *BeegoHttpRequest {
return newBeegoRequest(url, "DELETE")
func Delete(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "DELETE")
}
// Head returns *BeegoHttpRequest with HEAD method.
func Head(url string) *BeegoHttpRequest {
return newBeegoRequest(url, "HEAD")
func Head(url string) *BeegoHTTPRequest {
return NewBeegoRequest(url, "HEAD")
}
// BeegoHttpSettings
type BeegoHttpSettings struct {
// BeegoHTTPSettings is the http.Client setting
type BeegoHTTPSettings struct {
ShowDebug bool
UserAgent string
ConnectTimeout time.Duration
ReadWriteTimeout time.Duration
TlsClientConfig *tls.Config
TLSClientConfig *tls.Config
Proxy func(*http.Request) (*url.URL, error)
Transport http.RoundTripper
EnableCookie bool
Gzip bool
DumpBody bool
}
// BeegoHttpRequest provides more useful methods for requesting one url than http.Request.
type BeegoHttpRequest struct {
// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request.
type BeegoHTTPRequest struct {
url string
req *http.Request
params map[string]string
params map[string][]string
files map[string]string
setting BeegoHttpSettings
setting BeegoHTTPSettings
resp *http.Response
body []byte
dump []byte
}
// Change request settings
func (b *BeegoHttpRequest) Setting(setting BeegoHttpSettings) *BeegoHttpRequest {
// GetRequest return the request object
func (b *BeegoHTTPRequest) GetRequest() *http.Request {
return b.req
}
// Setting Change request settings
func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest {
b.setting = setting
return b
}
// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password.
func (b *BeegoHttpRequest) SetBasicAuth(username, password string) *BeegoHttpRequest {
func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest {
b.req.SetBasicAuth(username, password)
return b
}
// SetEnableCookie sets enable/disable cookiejar
func (b *BeegoHttpRequest) SetEnableCookie(enable bool) *BeegoHttpRequest {
func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest {
b.setting.EnableCookie = enable
return b
}
// SetUserAgent sets User-Agent header field
func (b *BeegoHttpRequest) SetUserAgent(useragent string) *BeegoHttpRequest {
func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest {
b.setting.UserAgent = useragent
return b
}
// Debug sets show debug or not when executing request.
func (b *BeegoHttpRequest) Debug(isdebug bool) *BeegoHttpRequest {
func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
b.setting.ShowDebug = isdebug
return b
}
// DumpBody setting whether need to Dump the Body.
func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
b.setting.DumpBody = isdump
return b
}
// DumpRequest return the DumpRequest
func (b *BeegoHTTPRequest) DumpRequest() []byte {
return b.dump
}
// SetTimeout sets connect time out and read-write time out for BeegoRequest.
func (b *BeegoHttpRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHttpRequest {
func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest {
b.setting.ConnectTimeout = connectTimeout
b.setting.ReadWriteTimeout = readWriteTimeout
return b
}
// SetTLSClientConfig sets tls connection configurations if visiting https url.
func (b *BeegoHttpRequest) SetTLSClientConfig(config *tls.Config) *BeegoHttpRequest {
b.setting.TlsClientConfig = config
func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest {
b.setting.TLSClientConfig = config
return b
}
// Header add header item string in request.
func (b *BeegoHttpRequest) Header(key, value string) *BeegoHttpRequest {
func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest {
b.req.Header.Set(key, value)
return b
}
// Set the protocol version for incoming requests.
// SetHost set the request host
func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest {
b.req.Host = host
return b
}
// SetProtocolVersion Set the protocol version for incoming requests.
// Client requests always use HTTP/1.1.
func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest {
func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
if len(vers) == 0 {
vers = "HTTP/1.1"
}
@ -202,44 +242,49 @@ func (b *BeegoHttpRequest) SetProtocolVersion(vers string) *BeegoHttpRequest {
}
// SetCookie add cookie into request.
func (b *BeegoHttpRequest) SetCookie(cookie *http.Cookie) *BeegoHttpRequest {
func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest {
b.req.Header.Add("Cookie", cookie.String())
return b
}
// Set transport to
func (b *BeegoHttpRequest) SetTransport(transport http.RoundTripper) *BeegoHttpRequest {
// SetTransport set the setting transport
func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest {
b.setting.Transport = transport
return b
}
// Set http proxy
// SetProxy set the http proxy
// example:
//
// func(req *http.Request) (*url.URL, error) {
// u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
// return u, nil
// }
func (b *BeegoHttpRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHttpRequest {
func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest {
b.setting.Proxy = proxy
return b
}
// Param adds query param in to request.
// params build query string as ?key1=value1&key2=value2...
func (b *BeegoHttpRequest) Param(key, value string) *BeegoHttpRequest {
b.params[key] = value
func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest {
if param, ok := b.params[key]; ok {
b.params[key] = append(param, value)
} else {
b.params[key] = []string{value}
}
return b
}
func (b *BeegoHttpRequest) PostFile(formname, filename string) *BeegoHttpRequest {
// PostFile add a post file to the request
func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest {
b.files[formname] = filename
return b
}
// Body adds request raw body.
// it supports string and []byte.
func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest {
func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
switch t := data.(type) {
case string:
bf := bytes.NewBufferString(t)
@ -253,7 +298,22 @@ func (b *BeegoHttpRequest) Body(data interface{}) *BeegoHttpRequest {
return b
}
func (b *BeegoHttpRequest) buildUrl(paramBody string) {
// JSONBody adds request raw body encoding by JSON.
func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil {
buf := bytes.NewBuffer(nil)
enc := json.NewEncoder(buf)
if err := enc.Encode(obj); err != nil {
return b, err
}
b.req.Body = ioutil.NopCloser(buf)
b.req.ContentLength = int64(buf.Len())
b.req.Header.Set("Content-Type", "application/json")
}
return b, nil
}
func (b *BeegoHTTPRequest) buildURL(paramBody string) {
// build GET url with query string
if b.req.Method == "GET" && len(paramBody) > 0 {
if strings.Index(b.url, "?") != -1 {
@ -264,8 +324,8 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) {
return
}
// build POST url and body
if b.req.Method == "POST" && b.req.Body == nil {
// build POST/PUT/PATCH url and body
if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH") && b.req.Body == nil {
// with files
if len(b.files) > 0 {
pr, pw := io.Pipe()
@ -274,21 +334,23 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) {
for formname, filename := range b.files {
fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
if err != nil {
log.Fatal(err)
log.Println("Httplib:", err)
}
fh, err := os.Open(filename)
if err != nil {
log.Fatal(err)
log.Println("Httplib:", err)
}
//iocopy
_, err = io.Copy(fileWriter, fh)
fh.Close()
if err != nil {
log.Fatal(err)
log.Println("Httplib:", err)
}
}
for k, v := range b.params {
bodyWriter.WriteField(k, v)
for _, vv := range v {
bodyWriter.WriteField(k, vv)
}
}
bodyWriter.Close()
pw.Close()
@ -306,24 +368,36 @@ func (b *BeegoHttpRequest) buildUrl(paramBody string) {
}
}
func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
if b.resp.StatusCode != 0 {
return b.resp, nil
}
resp, err := b.DoRequest()
if err != nil {
return nil, err
}
b.resp = resp
return resp, nil
}
// DoRequest will do the client.Do
func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
var paramBody string
if len(b.params) > 0 {
var buf bytes.Buffer
for k, v := range b.params {
for _, vv := range v {
buf.WriteString(url.QueryEscape(k))
buf.WriteByte('=')
buf.WriteString(url.QueryEscape(v))
buf.WriteString(url.QueryEscape(vv))
buf.WriteByte('&')
}
}
paramBody = buf.String()
paramBody = paramBody[0 : len(paramBody)-1]
}
b.buildUrl(paramBody)
b.buildURL(paramBody)
url, err := url.Parse(b.url)
if err != nil {
return nil, err
@ -336,7 +410,7 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
if trans == nil {
// create default transport
trans = &http.Transport{
TLSClientConfig: b.setting.TlsClientConfig,
TLSClientConfig: b.setting.TLSClientConfig,
Proxy: b.setting.Proxy,
Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
}
@ -344,7 +418,7 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
// if b.transport is *http.Transport then set the settings.
if t, ok := trans.(*http.Transport); ok {
if t.TLSClientConfig == nil {
t.TLSClientConfig = b.setting.TlsClientConfig
t.TLSClientConfig = b.setting.TLSClientConfig
}
if t.Proxy == nil {
t.Proxy = b.setting.Proxy
@ -355,7 +429,7 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
}
}
var jar http.CookieJar = nil
var jar http.CookieJar
if b.setting.EnableCookie {
if defaultCookieJar == nil {
createDefaultCookie()
@ -373,24 +447,18 @@ func (b *BeegoHttpRequest) getResponse() (*http.Response, error) {
}
if b.setting.ShowDebug {
dump, err := httputil.DumpRequest(b.req, true)
dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody)
if err != nil {
println(err.Error())
log.Println(err.Error())
}
println(string(dump))
b.dump = dump
}
resp, err := client.Do(b.req)
if err != nil {
return nil, err
}
b.resp = resp
return resp, nil
return client.Do(b.req)
}
// String returns the body string in response.
// it calls Response inner.
func (b *BeegoHttpRequest) String() (string, error) {
func (b *BeegoHTTPRequest) String() (string, error) {
data, err := b.Bytes()
if err != nil {
return "", err
@ -401,7 +469,7 @@ func (b *BeegoHttpRequest) String() (string, error) {
// Bytes returns the body []byte in response.
// it calls Response inner.
func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
if b.body != nil {
return b.body, nil
}
@ -413,16 +481,21 @@ func (b *BeegoHttpRequest) Bytes() ([]byte, error) {
return nil, nil
}
defer resp.Body.Close()
b.body, err = ioutil.ReadAll(resp.Body)
if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" {
reader, err := gzip.NewReader(resp.Body)
if err != nil {
return nil, err
}
return b.body, nil
b.body, err = ioutil.ReadAll(reader)
} else {
b.body, err = ioutil.ReadAll(resp.Body)
}
return b.body, err
}
// ToFile saves the body data in response to one file.
// it calls Response inner.
func (b *BeegoHttpRequest) ToFile(filename string) error {
func (b *BeegoHTTPRequest) ToFile(filename string) error {
f, err := os.Create(filename)
if err != nil {
return err
@ -441,9 +514,9 @@ func (b *BeegoHttpRequest) ToFile(filename string) error {
return err
}
// ToJson returns the map that marshals from the body bytes as json in response .
// ToJSON returns the map that marshals from the body bytes as json in response .
// it calls Response inner.
func (b *BeegoHttpRequest) ToJson(v interface{}) error {
func (b *BeegoHTTPRequest) ToJSON(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
@ -451,9 +524,9 @@ func (b *BeegoHttpRequest) ToJson(v interface{}) error {
return json.Unmarshal(data, v)
}
// ToXml returns the map that marshals from the body bytes as xml in response .
// ToXML returns the map that marshals from the body bytes as xml in response .
// it calls Response inner.
func (b *BeegoHttpRequest) ToXml(v interface{}) error {
func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
@ -462,7 +535,7 @@ func (b *BeegoHttpRequest) ToXml(v interface{}) error {
}
// Response executes request client gets response mannually.
func (b *BeegoHttpRequest) Response() (*http.Response, error) {
func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
return b.getResponse()
}
@ -473,7 +546,7 @@ func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, ad
if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(rwTimeout))
return conn, nil
err = conn.SetDeadline(time.Now().Add(rwTimeout))
return conn, err
}
}

View File

@ -19,6 +19,7 @@ import (
"os"
"strings"
"testing"
"time"
)
func TestResponse(t *testing.T) {
@ -149,10 +150,11 @@ func TestWithUserAgent(t *testing.T) {
func TestWithSetting(t *testing.T) {
v := "beego"
var setting BeegoHttpSettings
var setting BeegoHTTPSettings
setting.EnableCookie = true
setting.UserAgent = v
setting.Transport = nil
setting.ReadWriteTimeout = 5 * time.Second
SetDefaultSetting(setting)
str, err := Get("http://httpbin.org/get").String()
@ -176,11 +178,11 @@ func TestToJson(t *testing.T) {
t.Log(resp)
// httpbin will return http remote addr
type Ip struct {
type IP struct {
Origin string `json:"origin"`
}
var ip Ip
err = req.ToJson(&ip)
var ip IP
err = req.ToJSON(&ip)
if err != nil {
t.Fatal(err)
}

25
log.go
View File

@ -32,20 +32,20 @@ const (
LevelDebug
)
// SetLogLevel sets the global log level used by the simple
// logger.
// BeeLogger references the used application logger.
var BeeLogger = logs.NewLogger(100)
// SetLevel sets the global log level used by the simple logger.
func SetLevel(l int) {
BeeLogger.SetLevel(l)
}
// SetLogFuncCall set the CallDepth, default is 3
func SetLogFuncCall(b bool) {
BeeLogger.EnableFuncCallDepth(b)
BeeLogger.SetLogFuncCallDepth(3)
}
// logger references the used application logger.
var BeeLogger *logs.BeeLogger
// SetLogger sets a new logger.
func SetLogger(adaptername string, config string) error {
err := BeeLogger.SetLogger(adaptername, config)
@ -55,10 +55,12 @@ func SetLogger(adaptername string, config string) error {
return nil
}
// Emergency logs a message at emergency level.
func Emergency(v ...interface{}) {
BeeLogger.Emergency(generateFmtStr(len(v)), v...)
}
// Alert logs a message at alert level.
func Alert(v ...interface{}) {
BeeLogger.Alert(generateFmtStr(len(v)), v...)
}
@ -78,23 +80,24 @@ func Warning(v ...interface{}) {
BeeLogger.Warning(generateFmtStr(len(v)), v...)
}
// Deprecated: compatibility alias for Warning(), Will be removed in 1.5.0.
// Warn compatibility alias for Warning()
func Warn(v ...interface{}) {
Warning(v...)
BeeLogger.Warn(generateFmtStr(len(v)), v...)
}
// Notice logs a message at notice level.
func Notice(v ...interface{}) {
BeeLogger.Notice(generateFmtStr(len(v)), v...)
}
// Info logs a message at info level.
// Informational logs a message at info level.
func Informational(v ...interface{}) {
BeeLogger.Informational(generateFmtStr(len(v)), v...)
}
// Deprecated: compatibility alias for Warning(), Will be removed in 1.5.0.
// Info compatibility alias for Warning()
func Info(v ...interface{}) {
Informational(v...)
BeeLogger.Info(generateFmtStr(len(v)), v...)
}
// Debug logs a message at debug level.
@ -103,7 +106,7 @@ func Debug(v ...interface{}) {
}
// Trace logs a message at trace level.
// Deprecated: compatibility alias for Warning(), Will be removed in 1.5.0.
// compatibility alias for Warning()
func Trace(v ...interface{}) {
BeeLogger.Trace(generateFmtStr(len(v)), v...)
}

View File

@ -21,9 +21,9 @@ import (
"net"
)
// ConnWriter implements LoggerInterface.
// connWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection.
type ConnWriter struct {
type connWriter struct {
lg *log.Logger
innerWriter io.WriteCloser
ReconnectOnMsg bool `json:"reconnectOnMsg"`
@ -33,22 +33,22 @@ type ConnWriter struct {
Level int `json:"level"`
}
// create new ConnWrite returning as LoggerInterface.
func NewConn() LoggerInterface {
conn := new(ConnWriter)
// NewConn create new ConnWrite returning as LoggerInterface.
func NewConn() Logger {
conn := new(connWriter)
conn.Level = LevelTrace
return conn
}
// init connection writer with json config.
// Init init connection writer with json config.
// json config only need key "level".
func (c *ConnWriter) Init(jsonconfig string) error {
func (c *connWriter) Init(jsonconfig string) error {
return json.Unmarshal([]byte(jsonconfig), c)
}
// write message in connection.
// WriteMsg write message in connection.
// if connection is down, try to re-connect.
func (c *ConnWriter) WriteMsg(msg string, level int) error {
func (c *connWriter) WriteMsg(msg string, level int) error {
if level > c.Level {
return nil
}
@ -66,19 +66,19 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
return nil
}
// implementing method. empty.
func (c *ConnWriter) Flush() {
// Flush implementing method. empty.
func (c *connWriter) Flush() {
}
// destroy connection writer and close tcp listener.
func (c *ConnWriter) Destroy() {
// Destroy destroy connection writer and close tcp listener.
func (c *connWriter) Destroy() {
if c.innerWriter != nil {
c.innerWriter.Close()
}
}
func (c *ConnWriter) connect() error {
func (c *connWriter) connect() error {
if c.innerWriter != nil {
c.innerWriter.Close()
c.innerWriter = nil
@ -98,7 +98,7 @@ func (c *ConnWriter) connect() error {
return nil
}
func (c *ConnWriter) neddedConnectOnMsg() bool {
func (c *connWriter) neddedConnectOnMsg() bool {
if c.Reconnect {
c.Reconnect = false
return true

View File

@ -21,9 +21,11 @@ import (
"runtime"
)
type Brush func(string) string
// brush is a color join function
type brush func(string) string
func NewBrush(color string) Brush {
// newBrush return a fix color Brush
func newBrush(color string) brush {
pre := "\033["
reset := "\033[0m"
return func(text string) string {
@ -31,43 +33,43 @@ func NewBrush(color string) Brush {
}
}
var colors = []Brush{
NewBrush("1;37"), // Emergency white
NewBrush("1;36"), // Alert cyan
NewBrush("1;35"), // Critical magenta
NewBrush("1;31"), // Error red
NewBrush("1;33"), // Warning yellow
NewBrush("1;32"), // Notice green
NewBrush("1;34"), // Informational blue
NewBrush("1;34"), // Debug blue
var colors = []brush{
newBrush("1;37"), // Emergency white
newBrush("1;36"), // Alert cyan
newBrush("1;35"), // Critical magenta
newBrush("1;31"), // Error red
newBrush("1;33"), // Warning yellow
newBrush("1;32"), // Notice green
newBrush("1;34"), // Informational blue
newBrush("1;34"), // Debug blue
}
// ConsoleWriter implements LoggerInterface and writes messages to terminal.
type ConsoleWriter struct {
// consoleWriter implements LoggerInterface and writes messages to terminal.
type consoleWriter struct {
lg *log.Logger
Level int `json:"level"`
}
// create ConsoleWriter returning as LoggerInterface.
func NewConsole() LoggerInterface {
cw := &ConsoleWriter{
// NewConsole create ConsoleWriter returning as LoggerInterface.
func NewConsole() Logger {
cw := &consoleWriter{
lg: log.New(os.Stdout, "", log.Ldate|log.Ltime),
Level: LevelDebug,
}
return cw
}
// init console logger.
// Init init console logger.
// jsonconfig like '{"level":LevelTrace}'.
func (c *ConsoleWriter) Init(jsonconfig string) error {
func (c *consoleWriter) Init(jsonconfig string) error {
if len(jsonconfig) == 0 {
return nil
}
return json.Unmarshal([]byte(jsonconfig), c)
}
// write message in console.
func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
// WriteMsg write message in console.
func (c *consoleWriter) WriteMsg(msg string, level int) error {
if level > c.Level {
return nil
}
@ -80,13 +82,13 @@ func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
return nil
}
// implementing method. empty.
func (c *ConsoleWriter) Destroy() {
// Destroy implementing method. empty.
func (c *consoleWriter) Destroy() {
}
// implementing method. empty.
func (c *ConsoleWriter) Flush() {
// Flush implementing method. empty.
func (c *consoleWriter) Flush() {
}

View File

@ -42,12 +42,3 @@ func TestConsole(t *testing.T) {
log2.SetLogger("console", `{"level":3}`)
testConsoleCalls(log2)
}
func BenchmarkConsole(b *testing.B) {
log := NewLogger(10000)
log.EnableFuncCallDepth(true)
log.SetLogger("console", "")
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
}

80
logs/es/es.go Normal file
View File

@ -0,0 +1,80 @@
package es
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/url"
"time"
"github.com/astaxie/beego/logs"
"github.com/belogik/goes"
)
// NewES return a LoggerInterface
func NewES() logs.Logger {
cw := &esLogger{
Level: logs.LevelDebug,
}
return cw
}
type esLogger struct {
*goes.Connection
DSN string `json:"dsn"`
Level int `json:"level"`
}
// {"dsn":"http://localhost:9200/","level":1}
func (el *esLogger) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), el)
if err != nil {
return err
}
if el.DSN == "" {
return errors.New("empty dsn")
} else if u, err := url.Parse(el.DSN); err != nil {
return err
} else if u.Path == "" {
return errors.New("missing prefix")
} else if host, port, err := net.SplitHostPort(u.Host); err != nil {
return err
} else {
conn := goes.NewConnection(host, port)
el.Connection = conn
}
return nil
}
// WriteMsg will write the msg and level into es
func (el *esLogger) WriteMsg(msg string, level int) error {
if level > el.Level {
return nil
}
t := time.Now()
vals := make(map[string]interface{})
vals["@timestamp"] = t.Format(time.RFC3339)
vals["@msg"] = msg
d := goes.Document{
Index: fmt.Sprintf("%04d.%02d.%02d", t.Year(), t.Month(), t.Day()),
Type: "logs",
Fields: vals,
}
_, err := el.Index(d, nil)
return err
}
// Destroy is a empty method
func (el *esLogger) Destroy() {
}
// Flush is a empty method
func (el *esLogger) Flush() {
}
func init() {
logs.Register("es", NewES)
}

View File

@ -20,7 +20,6 @@ import (
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
@ -28,84 +27,62 @@ import (
"time"
)
// FileLogWriter implements LoggerInterface.
// fileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency.
type FileLogWriter struct {
*log.Logger
mw *MuxWriter
type fileLogWriter struct {
sync.Mutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize
// The opened file
Filename string `json:"filename"`
fileWriter *os.File
Maxlines int `json:"maxlines"`
maxlines_curlines int
// Rotate at line
MaxLines int `json:"maxlines"`
maxLinesCurLines int
// Rotate at size
Maxsize int `json:"maxsize"`
maxsize_cursize int
MaxSize int `json:"maxsize"`
maxSizeCurSize int
// Rotate daily
Daily bool `json:"daily"`
Maxdays int64 `json:"maxdays"`
daily_opendate int
MaxDays int64 `json:"maxdays"`
dailyOpenDate int
Rotate bool `json:"rotate"`
startLock sync.Mutex // Only one log can write to the file
Level int `json:"level"`
Perm os.FileMode `json:"perm"`
}
// an *os.File writer with locker.
type MuxWriter struct {
sync.Mutex
fd *os.File
}
// write to os.File.
func (l *MuxWriter) Write(b []byte) (int, error) {
l.Lock()
defer l.Unlock()
return l.fd.Write(b)
}
// set os.File in writer.
func (l *MuxWriter) SetFd(fd *os.File) {
if l.fd != nil {
l.fd.Close()
}
l.fd = fd
}
// create a FileLogWriter returning as LoggerInterface.
func NewFileWriter() LoggerInterface {
w := &FileLogWriter{
// NewFileWriter create a FileLogWriter returning as LoggerInterface.
func newFileWriter() Logger {
w := &fileLogWriter{
Filename: "",
Maxlines: 1000000,
Maxsize: 1 << 28, //256 MB
MaxLines: 1000000,
MaxSize: 1 << 28, //256 MB
Daily: true,
Maxdays: 7,
MaxDays: 7,
Rotate: true,
Level: LevelTrace,
Perm: 0660,
}
// use MuxWriter instead direct use os.File for lock write when rotate
w.mw = new(MuxWriter)
// set MuxWriter as Logger's io.Writer
w.Logger = log.New(w.mw, "", log.Ldate|log.Ltime)
return w
}
// Init file logger with json config.
// jsonconfig like:
// jsonConfig like:
// {
// "filename":"logs/beego.log",
// "maxlines":10000,
// "maxLines":10000,
// "maxsize":1<<30,
// "daily":true,
// "maxdays":15,
// "rotate":true
// "maxDays":15,
// "rotate":true,
// "perm":0600
// }
func (w *FileLogWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), w)
func (w *fileLogWriter) Init(jsonConfig string) error {
err := json.Unmarshal([]byte(jsonConfig), w)
if err != nil {
return err
}
@ -117,67 +94,120 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
}
// start file logger. create log file and set to locker-inside file writer.
func (w *FileLogWriter) startLogger() error {
fd, err := w.createLogFile()
func (w *fileLogWriter) startLogger() error {
file, err := w.createLogFile()
if err != nil {
return err
}
w.mw.SetFd(fd)
if w.fileWriter != nil {
w.fileWriter.Close()
}
w.fileWriter = file
return w.initFd()
}
func (w *FileLogWriter) docheck(size int) {
w.startLock.Lock()
defer w.startLock.Unlock()
if w.Rotate && ((w.Maxlines > 0 && w.maxlines_curlines >= w.Maxlines) ||
(w.Maxsize > 0 && w.maxsize_cursize >= w.Maxsize) ||
(w.Daily && time.Now().Day() != w.daily_opendate)) {
if err := w.DoRotate(); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
return
}
}
w.maxlines_curlines++
w.maxsize_cursize += size
func (w *fileLogWriter) needRotate(size int, day int) bool {
return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
(w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
(w.Daily && day != w.dailyOpenDate)
}
// write logger message into file.
func (w *FileLogWriter) WriteMsg(msg string, level int) error {
// WriteMsg write logger message into file.
func (w *fileLogWriter) WriteMsg(msg string, level int) error {
if level > w.Level {
return nil
}
n := 24 + len(msg) // 24 stand for the length "2013/06/23 21:00:22 [T] "
w.docheck(n)
w.Logger.Println(msg)
return nil
//2016/01/12 21:34:33
now := time.Now()
y, mo, d := now.Date()
h, mi, s := now.Clock()
//len(2006/01/02 15:03:04)==19
var buf [20]byte
t := 3
for y >= 10 {
p := y / 10
buf[t] = byte('0' + y - p*10)
y = p
t--
}
buf[0] = byte('0' + y)
buf[4] = '/'
if mo > 9 {
buf[5] = '1'
buf[6] = byte('0' + mo - 9)
} else {
buf[5] = '0'
buf[6] = byte('0' + mo)
}
buf[7] = '/'
t = d / 10
buf[8] = byte('0' + t)
buf[9] = byte('0' + d - t*10)
buf[10] = ' '
t = h / 10
buf[11] = byte('0' + t)
buf[12] = byte('0' + h - t*10)
buf[13] = ':'
t = mi / 10
buf[14] = byte('0' + t)
buf[15] = byte('0' + mi - t*10)
buf[16] = ':'
t = s / 10
buf[17] = byte('0' + t)
buf[18] = byte('0' + s - t*10)
buf[19] = ' '
msg = string(buf[0:]) + msg + "\n"
if w.Rotate {
if w.needRotate(len(msg), d) {
w.Lock()
if w.needRotate(len(msg), d) {
if err := w.doRotate(); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
}
}
w.Lock()
_, err := w.fileWriter.Write([]byte(msg))
if err == nil {
w.maxLinesCurLines++
w.maxSizeCurSize += len(msg)
}
w.Unlock()
return err
}
func (w *FileLogWriter) createLogFile() (*os.File, error) {
func (w *fileLogWriter) createLogFile() (*os.File, error) {
// Open the log file
fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660)
fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, w.Perm)
return fd, err
}
func (w *FileLogWriter) initFd() error {
fd := w.mw.fd
finfo, err := fd.Stat()
func (w *fileLogWriter) initFd() error {
fd := w.fileWriter
fInfo, err := fd.Stat()
if err != nil {
return fmt.Errorf("get stat err: %s\n", err)
}
w.maxsize_cursize = int(finfo.Size())
w.daily_opendate = time.Now().Day()
w.maxlines_curlines = 0
if finfo.Size() > 0 {
w.maxSizeCurSize = int(fInfo.Size())
w.dailyOpenDate = time.Now().Day()
w.maxLinesCurLines = 0
if fInfo.Size() > 0 {
count, err := w.lines()
if err != nil {
return err
}
w.maxlines_curlines = count
w.maxLinesCurLines = count
}
return nil
}
func (w *FileLogWriter) lines() (int, error) {
func (w *fileLogWriter) lines() (int, error) {
fd, err := os.Open(w.Filename)
if err != nil {
return 0, err
@ -205,59 +235,60 @@ func (w *FileLogWriter) lines() (int, error) {
}
// DoRotate means it need to write file in new file.
// new file name like xx.log.2013-01-01.2
func (w *FileLogWriter) DoRotate() error {
// new file name like xx.2013-01-01.2.log
func (w *fileLogWriter) doRotate() error {
_, err := os.Lstat(w.Filename)
if err == nil { // file exists
if err != nil {
return err
}
// file exists
// Find the next available number
num := 1
fname := ""
fName := ""
suffix := filepath.Ext(w.Filename)
filenameOnly := strings.TrimSuffix(w.Filename, suffix)
if suffix == "" {
suffix = ".log"
}
for ; err == nil && num <= 999; num++ {
fname = w.Filename + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), num)
_, err = os.Lstat(fname)
fName = filenameOnly + fmt.Sprintf(".%s.%03d%s", time.Now().Format("2006-01-02"), num, suffix)
_, err = os.Lstat(fName)
}
// return error if the last file checked still existed
if err == nil {
return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename)
}
// block Logger's io.Writer
w.mw.Lock()
defer w.mw.Unlock()
fd := w.mw.fd
fd.Close()
// close fd before rename
// Rename the file to its newfound home
err = os.Rename(w.Filename, fname)
if err != nil {
return fmt.Errorf("Rotate: %s\n", err)
}
// close fileWriter before rename
w.fileWriter.Close()
// Rename the file to its new found name
// even if occurs error,we MUST guarantee to restart new logger
renameErr := os.Rename(w.Filename, fName)
// re-start logger
err = w.startLogger()
if err != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", err)
}
startLoggerErr := w.startLogger()
go w.deleteOldLog()
}
if startLoggerErr != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr)
}
if renameErr != nil {
return fmt.Errorf("Rotate: %s\n", renameErr)
}
return nil
}
func (w *FileLogWriter) deleteOldLog() {
func (w *fileLogWriter) deleteOldLog() {
dir := filepath.Dir(w.Filename)
filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) {
defer func() {
if r := recover(); r != nil {
returnErr = fmt.Errorf("Unable to delete old log '%s', error: %+v", path, r)
fmt.Println(returnErr)
fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r)
}
}()
if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.Maxdays) {
if !info.IsDir() && info.ModTime().Unix() < (time.Now().Unix()-60*60*24*w.MaxDays) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.Filename)) {
os.Remove(path)
}
@ -266,18 +297,18 @@ func (w *FileLogWriter) deleteOldLog() {
})
}
// destroy file logger, close file writer.
func (w *FileLogWriter) Destroy() {
w.mw.fd.Close()
// Destroy close the file description, close file writer.
func (w *fileLogWriter) Destroy() {
w.fileWriter.Close()
}
// flush file logger.
// Flush flush file logger.
// there are no buffering messages in file logger in memory.
// flush file means sync file from disk.
func (w *FileLogWriter) Flush() {
w.mw.fd.Sync()
func (w *fileLogWriter) Flush() {
w.fileWriter.Sync()
}
func init() {
Register("file", NewFileWriter)
Register("file", newFileWriter)
}

View File

@ -23,7 +23,7 @@ import (
"time"
)
func TestFile(t *testing.T) {
func TestFile1(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`)
log.Debug("debug")
@ -34,25 +34,24 @@ func TestFile(t *testing.T) {
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
time.Sleep(time.Second * 4)
f, err := os.Open("test.log")
if err != nil {
t.Fatal(err)
}
b := bufio.NewReader(f)
linenum := 0
lineNum := 0
for {
line, _, err := b.ReadLine()
if err != nil {
break
}
if len(line) > 0 {
linenum++
lineNum++
}
}
var expected = LevelDebug + 1
if linenum != expected {
t.Fatal(linenum, "not "+strconv.Itoa(expected)+" lines")
if lineNum != expected {
t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
}
os.Remove("test.log")
}
@ -68,25 +67,24 @@ func TestFile2(t *testing.T) {
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
time.Sleep(time.Second * 4)
f, err := os.Open("test2.log")
if err != nil {
t.Fatal(err)
}
b := bufio.NewReader(f)
linenum := 0
lineNum := 0
for {
line, _, err := b.ReadLine()
if err != nil {
break
}
if len(line) > 0 {
linenum++
lineNum++
}
}
var expected = LevelError + 1
if linenum != expected {
t.Fatal(linenum, "not "+strconv.Itoa(expected)+" lines")
if lineNum != expected {
t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
}
os.Remove("test2.log")
}
@ -102,13 +100,13 @@ func TestFileRotate(t *testing.T) {
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
time.Sleep(time.Second * 4)
rotatename := "test3.log" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1)
b, err := exists(rotatename)
rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
b, err := exists(rotateName)
if !b || err != nil {
os.Remove("test3.log")
t.Fatal("rotate not generated")
}
os.Remove(rotatename)
os.Remove(rotateName)
os.Remove("test3.log")
}
@ -131,3 +129,45 @@ func BenchmarkFile(b *testing.B) {
}
os.Remove("test4.log")
}
func BenchmarkFileAsynchronous(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
log.Async()
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
os.Remove("test4.log")
}
func BenchmarkFileCallDepth(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
log.EnableFuncCallDepth(true)
log.SetLogFuncCallDepth(2)
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
os.Remove("test4.log")
}
func BenchmarkFileAsynchronousCallDepth(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
log.EnableFuncCallDepth(true)
log.SetLogFuncCallDepth(2)
log.Async()
for i := 0; i < b.N; i++ {
log.Debug("debug")
}
os.Remove("test4.log")
}
func BenchmarkFileOnGoroutine(b *testing.B) {
log := NewLogger(100000)
log.SetLogger("file", `{"filename":"test4.log"}`)
for i := 0; i < b.N; i++ {
go log.Debug("debug")
}
os.Remove("test4.log")
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Package logs provide a general log interface
// Usage:
//
// import "github.com/astaxie/beego/logs"
@ -34,8 +35,10 @@ package logs
import (
"fmt"
"os"
"path"
"runtime"
"strconv"
"sync"
)
@ -60,10 +63,10 @@ const (
LevelWarn = LevelWarning
)
type loggerType func() LoggerInterface
type loggerType func() Logger
// LoggerInterface defines the behavior of a log provider.
type LoggerInterface interface {
// Logger defines the behavior of a log provider.
type Logger interface {
Init(config string) error
WriteMsg(msg string, level int) error
Destroy()
@ -92,8 +95,14 @@ type BeeLogger struct {
level int
enableFuncCallDepth bool
loggerFuncCallDepth int
msg chan *logMsg
outputs map[string]LoggerInterface
asynchronous bool
msgChan chan *logMsg
outputs []*nameLogger
}
type nameLogger struct {
Logger
name string
}
type logMsg struct {
@ -101,90 +110,117 @@ type logMsg struct {
msg string
}
var logMsgPool *sync.Pool
// NewLogger returns a new BeeLogger.
// channellen means the number of messages in chan.
// channelLen means the number of messages in chan(used where asynchronous is true).
// if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channellen int64) *BeeLogger {
func NewLogger(channelLen int64) *BeeLogger {
bl := new(BeeLogger)
bl.level = LevelDebug
bl.loggerFuncCallDepth = 2
bl.msg = make(chan *logMsg, channellen)
bl.outputs = make(map[string]LoggerInterface)
//bl.SetLogger("console", "") // default output to console
bl.msgChan = make(chan *logMsg, channelLen)
return bl
}
// Async set the log to asynchronous and start the goroutine
func (bl *BeeLogger) Async() *BeeLogger {
bl.asynchronous = true
logMsgPool = &sync.Pool{
New: func() interface{} {
return &logMsg{}
},
}
go bl.startLogger()
return bl
}
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
func (bl *BeeLogger) SetLogger(adapterName string, config string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
if log, ok := adapters[adaptername]; ok {
if log, ok := adapters[adapterName]; ok {
lg := log()
err := lg.Init(config)
bl.outputs[adaptername] = lg
if err != nil {
fmt.Println("logs.BeeLogger.SetLogger: " + err.Error())
fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error())
return err
}
bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg})
} else {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername)
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
}
return nil
}
// remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adaptername string) error {
// DelLogger remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adapterName string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
if lg, ok := bl.outputs[adaptername]; ok {
outputs := []*nameLogger{}
for _, lg := range bl.outputs {
if lg.name == adapterName {
lg.Destroy()
delete(bl.outputs, adaptername)
return nil
} else {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adaptername)
outputs = append(outputs, lg)
}
}
if len(outputs) == len(bl.outputs) {
return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
}
bl.outputs = outputs
return nil
}
func (bl *BeeLogger) writeToLoggers(msg string, level int) {
for _, l := range bl.outputs {
err := l.WriteMsg(msg, level)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err)
}
}
}
func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
if loglevel > bl.level {
return nil
}
lm := new(logMsg)
lm.level = loglevel
func (bl *BeeLogger) writeMsg(logLevel int, msg string) error {
if bl.enableFuncCallDepth {
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
if _, filename := path.Split(file); filename == "log.go" && (line == 97 || line == 83) {
_, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth + 1)
if !ok {
file = "???"
line = 0
}
if ok {
_, filename := path.Split(file)
lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg)
} else {
lm.msg = msg
msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "]" + msg
}
} else {
if bl.asynchronous {
lm := logMsgPool.Get().(*logMsg)
lm.level = logLevel
lm.msg = msg
bl.msgChan <- lm
} else {
bl.writeToLoggers(msg, logLevel)
}
bl.msg <- lm
return nil
}
// Set log message level.
//
// SetLevel Set log message level.
// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning),
// log providers will not even be sent the message.
func (bl *BeeLogger) SetLevel(l int) {
bl.level = l
}
// set log funcCallDepth
// SetLogFuncCallDepth set log funcCallDepth
func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
bl.loggerFuncCallDepth = d
}
// enable log funcCallDepth
// GetLogFuncCallDepth return log funcCallDepth for wrapper
func (bl *BeeLogger) GetLogFuncCallDepth() int {
return bl.loggerFuncCallDepth
}
// EnableFuncCallDepth enable log funcCallDepth
func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
bl.enableFuncCallDepth = b
}
@ -194,104 +230,129 @@ func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
func (bl *BeeLogger) startLogger() {
for {
select {
case bm := <-bl.msg:
for _, l := range bl.outputs {
err := l.WriteMsg(bm.msg, bm.level)
if err != nil {
fmt.Println("ERROR, unable to WriteMsg:", err)
}
}
case bm := <-bl.msgChan:
bl.writeToLoggers(bm.msg, bm.level)
logMsgPool.Put(bm)
}
}
}
// Log EMERGENCY level message.
// Emergency Log EMERGENCY level message.
func (bl *BeeLogger) Emergency(format string, v ...interface{}) {
if LevelEmergency > bl.level {
return
}
msg := fmt.Sprintf("[M] "+format, v...)
bl.writerMsg(LevelEmergency, msg)
bl.writeMsg(LevelEmergency, msg)
}
// Log ALERT level message.
// Alert Log ALERT level message.
func (bl *BeeLogger) Alert(format string, v ...interface{}) {
if LevelAlert > bl.level {
return
}
msg := fmt.Sprintf("[A] "+format, v...)
bl.writerMsg(LevelAlert, msg)
bl.writeMsg(LevelAlert, msg)
}
// Log CRITICAL level message.
// Critical Log CRITICAL level message.
func (bl *BeeLogger) Critical(format string, v ...interface{}) {
if LevelCritical > bl.level {
return
}
msg := fmt.Sprintf("[C] "+format, v...)
bl.writerMsg(LevelCritical, msg)
bl.writeMsg(LevelCritical, msg)
}
// Log ERROR level message.
// Error Log ERROR level message.
func (bl *BeeLogger) Error(format string, v ...interface{}) {
if LevelError > bl.level {
return
}
msg := fmt.Sprintf("[E] "+format, v...)
bl.writerMsg(LevelError, msg)
bl.writeMsg(LevelError, msg)
}
// Log WARNING level message.
// Warning Log WARNING level message.
func (bl *BeeLogger) Warning(format string, v ...interface{}) {
if LevelWarning > bl.level {
return
}
msg := fmt.Sprintf("[W] "+format, v...)
bl.writerMsg(LevelWarning, msg)
bl.writeMsg(LevelWarning, msg)
}
// Log NOTICE level message.
// Notice Log NOTICE level message.
func (bl *BeeLogger) Notice(format string, v ...interface{}) {
if LevelNotice > bl.level {
return
}
msg := fmt.Sprintf("[N] "+format, v...)
bl.writerMsg(LevelNotice, msg)
bl.writeMsg(LevelNotice, msg)
}
// Log INFORMATIONAL level message.
// Informational Log INFORMATIONAL level message.
func (bl *BeeLogger) Informational(format string, v ...interface{}) {
if LevelInformational > bl.level {
return
}
msg := fmt.Sprintf("[I] "+format, v...)
bl.writerMsg(LevelInformational, msg)
bl.writeMsg(LevelInformational, msg)
}
// Log DEBUG level message.
// Debug Log DEBUG level message.
func (bl *BeeLogger) Debug(format string, v ...interface{}) {
if LevelDebug > bl.level {
return
}
msg := fmt.Sprintf("[D] "+format, v...)
bl.writerMsg(LevelDebug, msg)
bl.writeMsg(LevelDebug, msg)
}
// Log WARN level message.
//
// Deprecated: compatibility alias for Warning(), Will be removed in 1.5.0.
// Warn Log WARN level message.
// compatibility alias for Warning()
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
bl.Warning(format, v...)
if LevelWarning > bl.level {
return
}
msg := fmt.Sprintf("[W] "+format, v...)
bl.writeMsg(LevelWarning, msg)
}
// Log INFO level message.
//
// Deprecated: compatibility alias for Informational(), Will be removed in 1.5.0.
// Info Log INFO level message.
// compatibility alias for Informational()
func (bl *BeeLogger) Info(format string, v ...interface{}) {
bl.Informational(format, v...)
if LevelInformational > bl.level {
return
}
msg := fmt.Sprintf("[I] "+format, v...)
bl.writeMsg(LevelInformational, msg)
}
// Log TRACE level message.
//
// Deprecated: compatibility alias for Debug(), Will be removed in 1.5.0.
// Trace Log TRACE level message.
// compatibility alias for Debug()
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
bl.Debug(format, v...)
if LevelDebug > bl.level {
return
}
msg := fmt.Sprintf("[D] "+format, v...)
bl.writeMsg(LevelDebug, msg)
}
// flush all chan data.
// Flush flush all chan data.
func (bl *BeeLogger) Flush() {
for _, l := range bl.outputs {
l.Flush()
}
}
// close logger, flush all chan data and destroy all adapters in BeeLogger.
// Close close logger, flush all chan data and destroy all adapters in BeeLogger.
func (bl *BeeLogger) Close() {
for {
if len(bl.msg) > 0 {
bm := <-bl.msg
for _, l := range bl.outputs {
err := l.WriteMsg(bm.msg, bm.level)
if err != nil {
fmt.Println("ERROR, unable to WriteMsg (while closing logger):", err)
}
}
if len(bl.msgChan) > 0 {
bm := <-bl.msgChan
bl.writeToLoggers(bm.msg, bm.level)
logMsgPool.Put(bm)
continue
}
break

View File

@ -24,31 +24,26 @@ import (
"time"
)
const (
// no usage
// subjectPhrase = "Diagnostic message from server"
)
// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
type SmtpWriter struct {
Username string `json:"Username"`
// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server.
type SMTPWriter struct {
Username string `json:"username"`
Password string `json:"password"`
Host string `json:"Host"`
Host string `json:"host"`
Subject string `json:"subject"`
FromAddress string `json:"fromAddress"`
RecipientAddresses []string `json:"sendTos"`
Level int `json:"level"`
}
// create smtp writer.
func NewSmtpWriter() LoggerInterface {
return &SmtpWriter{Level: LevelTrace}
// NewSMTPWriter create smtp writer.
func newSMTPWriter() Logger {
return &SMTPWriter{Level: LevelTrace}
}
// init smtp writer with json config.
// Init smtp writer with json config.
// config like:
// {
// "Username":"example@gmail.com",
// "username":"example@gmail.com",
// "password:"password",
// "host":"smtp.gmail.com:465",
// "subject":"email title",
@ -56,7 +51,7 @@ func NewSmtpWriter() LoggerInterface {
// "sendTos":["email1","email2"],
// "level":LevelError
// }
func (s *SmtpWriter) Init(jsonconfig string) error {
func (s *SMTPWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s)
if err != nil {
return err
@ -64,7 +59,7 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
return nil
}
func (s *SmtpWriter) GetSmtpAuth(host string) smtp.Auth {
func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth {
if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 {
return nil
}
@ -76,7 +71,7 @@ func (s *SmtpWriter) GetSmtpAuth(host string) smtp.Auth {
)
}
func (s *SmtpWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error {
func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error {
client, err := smtp.Dial(hostAddressWithPort)
if err != nil {
return err
@ -129,9 +124,9 @@ func (s *SmtpWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd
return nil
}
// write message in smtp writer.
// WriteMsg write message in smtp writer.
// it will send an email with subject and only this message.
func (s *SmtpWriter) WriteMsg(msg string, level int) error {
func (s *SMTPWriter) WriteMsg(msg string, level int) error {
if level > s.Level {
return nil
}
@ -139,27 +134,27 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
hp := strings.Split(s.Host, ":")
// Set up authentication information.
auth := s.GetSmtpAuth(hp[0])
auth := s.getSMTPAuth(hp[0])
// Connect to the server, authenticate, set the sender and recipient,
// and send the email all in one step.
content_type := "Content-Type: text/plain" + "; charset=UTF-8"
contentType := "Content-Type: text/plain" + "; charset=UTF-8"
mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress +
">\r\nSubject: " + s.Subject + "\r\n" + content_type + "\r\n\r\n" + fmt.Sprintf(".%s", time.Now().Format("2006-01-02 15:04:05")) + msg)
">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", time.Now().Format("2006-01-02 15:04:05")) + msg)
return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg)
}
// implementing method. empty.
func (s *SmtpWriter) Flush() {
// Flush implementing method. empty.
func (s *SMTPWriter) Flush() {
return
}
// implementing method. empty.
func (s *SmtpWriter) Destroy() {
// Destroy implementing method. empty.
func (s *SMTPWriter) Destroy() {
return
}
func init() {
Register("smtp", NewSmtpWriter)
Register("smtp", newSMTPWriter)
}

View File

@ -1,212 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package beego
import (
"bytes"
"compress/flate"
"compress/gzip"
"errors"
"io"
"io/ioutil"
"net/http"
"os"
"strings"
"sync"
"time"
)
var gmfim map[string]*memFileInfo = make(map[string]*memFileInfo)
var lock sync.RWMutex
// OpenMemZipFile returns MemFile object with a compressed static file.
// it's used for serve static file if gzip enable.
func openMemZipFile(path string, zip string) (*memFile, error) {
osfile, e := os.Open(path)
if e != nil {
return nil, e
}
defer osfile.Close()
osfileinfo, e := osfile.Stat()
if e != nil {
return nil, e
}
modtime := osfileinfo.ModTime()
fileSize := osfileinfo.Size()
lock.RLock()
cfi, ok := gmfim[zip+":"+path]
lock.RUnlock()
if !(ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize) {
var content []byte
if zip == "gzip" {
var zipbuf bytes.Buffer
gzipwriter, e := gzip.NewWriterLevel(&zipbuf, gzip.BestCompression)
if e != nil {
return nil, e
}
_, e = io.Copy(gzipwriter, osfile)
gzipwriter.Close()
if e != nil {
return nil, e
}
content, e = ioutil.ReadAll(&zipbuf)
if e != nil {
return nil, e
}
} else if zip == "deflate" {
var zipbuf bytes.Buffer
deflatewriter, e := flate.NewWriter(&zipbuf, flate.BestCompression)
if e != nil {
return nil, e
}
_, e = io.Copy(deflatewriter, osfile)
deflatewriter.Close()
if e != nil {
return nil, e
}
content, e = ioutil.ReadAll(&zipbuf)
if e != nil {
return nil, e
}
} else {
content, e = ioutil.ReadAll(osfile)
if e != nil {
return nil, e
}
}
cfi = &memFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
lock.Lock()
defer lock.Unlock()
gmfim[zip+":"+path] = cfi
}
return &memFile{fi: cfi, offset: 0}, nil
}
// MemFileInfo contains a compressed file bytes and file information.
// it implements os.FileInfo interface.
type memFileInfo struct {
os.FileInfo
modTime time.Time
content []byte
contentSize int64
fileSize int64
}
// Name returns the compressed filename.
func (fi *memFileInfo) Name() string {
return fi.Name()
}
// Size returns the raw file content size, not compressed size.
func (fi *memFileInfo) Size() int64 {
return fi.contentSize
}
// Mode returns file mode.
func (fi *memFileInfo) Mode() os.FileMode {
return fi.Mode()
}
// ModTime returns the last modified time of raw file.
func (fi *memFileInfo) ModTime() time.Time {
return fi.modTime
}
// IsDir returns the compressing file is a directory or not.
func (fi *memFileInfo) IsDir() bool {
return fi.IsDir()
}
// return nil. implement the os.FileInfo interface method.
func (fi *memFileInfo) Sys() interface{} {
return nil
}
// MemFile contains MemFileInfo and bytes offset when reading.
// it implements io.Reader,io.ReadCloser and io.Seeker.
type memFile struct {
fi *memFileInfo
offset int64
}
// Close memfile.
func (f *memFile) Close() error {
return nil
}
// Get os.FileInfo of memfile.
func (f *memFile) Stat() (os.FileInfo, error) {
return f.fi, nil
}
// read os.FileInfo of files in directory of memfile.
// it returns empty slice.
func (f *memFile) Readdir(count int) ([]os.FileInfo, error) {
infos := []os.FileInfo{}
return infos, nil
}
// Read bytes from the compressed file bytes.
func (f *memFile) Read(p []byte) (n int, err error) {
if len(f.fi.content)-int(f.offset) >= len(p) {
n = len(p)
} else {
n = len(f.fi.content) - int(f.offset)
err = io.EOF
}
copy(p, f.fi.content[f.offset:f.offset+int64(n)])
f.offset += int64(n)
return
}
var errWhence = errors.New("Seek: invalid whence")
var errOffset = errors.New("Seek: invalid offset")
// Read bytes from the compressed file bytes by seeker.
func (f *memFile) Seek(offset int64, whence int) (ret int64, err error) {
switch whence {
default:
return 0, errWhence
case os.SEEK_SET:
case os.SEEK_CUR:
offset += f.offset
case os.SEEK_END:
offset += int64(len(f.fi.content))
}
if offset < 0 || int(offset) > len(f.fi.content) {
return 0, errOffset
}
f.offset = offset
return f.offset, nil
}
// GetAcceptEncodingZip returns accept encoding format in http header.
// zip is first, then deflate if both accepted.
// If no accepted, return empty string.
func getAcceptEncodingZip(r *http.Request) string {
ss := r.Header.Get("Accept-Encoding")
ss = strings.ToLower(ss)
if strings.Contains(ss, "gzip") {
return "gzip"
} else if strings.Contains(ss, "deflate") {
return "deflate"
} else {
return ""
}
}

View File

@ -1,71 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Usage:
//
// import "github.com/astaxie/beego/middleware"
//
// I18N = middleware.NewLocale("conf/i18n.conf", beego.AppConfig.String("language"))
//
// more docs: http://beego.me/docs/module/i18n.md
package middleware
import (
"encoding/json"
"io/ioutil"
"os"
)
type Translation struct {
filepath string
CurrentLocal string
Locales map[string]map[string]string
}
func NewLocale(filepath string, defaultlocal string) *Translation {
file, err := os.Open(filepath)
if err != nil {
panic("open " + filepath + " err :" + err.Error())
}
data, err := ioutil.ReadAll(file)
if err != nil {
panic("read " + filepath + " err :" + err.Error())
}
i18n := make(map[string]map[string]string)
if err = json.Unmarshal(data, &i18n); err != nil {
panic("json.Unmarshal " + filepath + " err :" + err.Error())
}
return &Translation{
filepath: filepath,
CurrentLocal: defaultlocal,
Locales: i18n,
}
}
func (t *Translation) SetLocale(local string) {
t.CurrentLocal = local
}
func (t *Translation) Translate(key string, local string) string {
if local == "" {
local = t.CurrentLocal
}
if ct, ok := t.Locales[key]; ok {
if v, o := ct[local]; o {
return v
}
}
return key
}

View File

@ -14,33 +14,40 @@
package migration
// Table store the tablename and Column
type Table struct {
TableName string
Columns []*Column
}
// Create return the create sql
func (t *Table) Create() string {
return ""
}
// Drop return the drop sql
func (t *Table) Drop() string {
return ""
}
// Column define the columns name type and Default
type Column struct {
Name string
Type string
Default interface{}
}
// Create return create sql with the provided tbname and columns
func Create(tbname string, columns ...Column) string {
return ""
}
// Drop return the drop sql with the provided tbname and columns
func Drop(tbname string, columns ...Column) string {
return ""
}
// TableDDL is still in think
func TableDDL(tbname string, columns ...Column) string {
return ""
}

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// migration package for migration
// Package migration is used for migration
//
// The table structure is as follow:
//
@ -39,8 +39,8 @@ import (
// const the data format for the bee generate migration datatype
const (
M_DATE_FORMAT = "20060102_150405"
M_DB_DATE_FORMAT = "2006-01-02 15:04:05"
DateFormat = "20060102_150405"
DBDateFormat = "2006-01-02 15:04:05"
)
// Migrationer is an interface for all Migration struct
@ -60,24 +60,24 @@ func init() {
migrationMap = make(map[string]Migrationer)
}
// the basic type which will implement the basic type
// Migration the basic type which will implement the basic type
type Migration struct {
sqls []string
Created string
}
// implement in the Inheritance struct for upgrade
// Up implement in the Inheritance struct for upgrade
func (m *Migration) Up() {
}
// implement in the Inheritance struct for down
// Down implement in the Inheritance struct for down
func (m *Migration) Down() {
}
// add sql want to execute
func (m *Migration) Sql(sql string) {
// SQL add sql want to execute
func (m *Migration) SQL(sql string) {
m.sqls = append(m.sqls, sql)
}
@ -86,7 +86,7 @@ func (m *Migration) Reset() {
m.sqls = make([]string, 0)
}
// execute the sql already add in the sql
// Exec execute the sql already add in the sql
func (m *Migration) Exec(name, status string) error {
o := orm.NewOrm()
for _, s := range m.sqls {
@ -104,33 +104,32 @@ func (m *Migration) addOrUpdateRecord(name, status string) error {
o := orm.NewOrm()
if status == "down" {
status = "rollback"
p, err := o.Raw("update migrations set `status` = ?, `rollback_statements` = ?, `created_at` = ? where name = ?").Prepare()
p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare()
if err != nil {
return nil
}
_, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(M_DB_DATE_FORMAT), name)
_, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name)
return err
} else {
}
status = "update"
p, err := o.Raw("insert into migrations(`name`, `created_at`, `statements`, `status`) values(?,?,?,?)").Prepare()
p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare()
if err != nil {
return err
}
_, err = p.Exec(name, time.Now().Format(M_DB_DATE_FORMAT), strings.Join(m.sqls, "; "), status)
_, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status)
return err
}
}
// get the unixtime from the Created
// GetCreated get the unixtime from the Created
func (m *Migration) GetCreated() int64 {
t, err := time.Parse(M_DATE_FORMAT, m.Created)
t, err := time.Parse(DateFormat, m.Created)
if err != nil {
return 0
}
return t.Unix()
}
// register the Migration in the map
// Register register the Migration in the map
func Register(name string, m Migrationer) error {
if _, ok := migrationMap[name]; ok {
return errors.New("already exist name:" + name)
@ -139,7 +138,7 @@ func Register(name string, m Migrationer) error {
return nil
}
// upgrate the migration from lasttime
// Upgrade upgrate the migration from lasttime
func Upgrade(lasttime int64) error {
sm := sortMap(migrationMap)
i := 0
@ -163,7 +162,7 @@ func Upgrade(lasttime int64) error {
return nil
}
//rollback the migration by the name
// Rollback rollback the migration by the name
func Rollback(name string) error {
if v, ok := migrationMap[name]; ok {
beego.Info("start rollback")
@ -178,14 +177,13 @@ func Rollback(name string) error {
beego.Info("end rollback")
time.Sleep(2 * time.Second)
return nil
} else {
}
beego.Error("not exist the migrationMap name:" + name)
time.Sleep(2 * time.Second)
return errors.New("not exist the migrationMap name:" + name)
}
}
// reset all migration
// Reset reset all migration
// run all migration's down function
func Reset() error {
sm := sortMap(migrationMap)
@ -214,7 +212,7 @@ func Reset() error {
return nil
}
// first Reset, then Upgrade
// Refresh first Reset, then Upgrade
func Refresh() error {
err := Reset()
if err != nil {

14
mime.go
View File

@ -14,11 +14,7 @@
package beego
import (
"mime"
)
var mimemaps map[string]string = map[string]string{
var mimemaps = map[string]string{
".3dm": "x-world/x-3dmf",
".3dmf": "x-world/x-3dmf",
".7z": "application/x-7z-compressed",
@ -40,6 +36,7 @@ var mimemaps map[string]string = map[string]string{
".ani": "application/x-navi-animation",
".aos": "application/x-nokia-9000-communicator-add-on-software",
".aps": "application/mime",
".apk": "application/vnd.android.package-archive",
".arc": "application/x-arc-compressed",
".arj": "application/arj",
".art": "image/x-jg",
@ -557,10 +554,3 @@ var mimemaps map[string]string = map[string]string{
".oex": "application/x-opera-extension",
".mustache": "text/html",
}
func initMime() error {
for k, v := range mimemaps {
mime.AddExtensionType(k, v)
}
return nil
}

View File

@ -23,16 +23,17 @@ import (
type namespaceCond func(*beecontext.Context) bool
type innnerNamespace func(*Namespace)
// LinkNamespace used as link action
type LinkNamespace func(*Namespace)
// Namespace is store all the info
type Namespace struct {
prefix string
handlers *ControllerRegistor
handlers *ControllerRegister
}
// get new Namespace
func NewNamespace(prefix string, params ...innnerNamespace) *Namespace {
// NewNamespace get new Namespace
func NewNamespace(prefix string, params ...LinkNamespace) *Namespace {
ns := &Namespace{
prefix: prefix,
handlers: NewControllerRegister(),
@ -43,7 +44,7 @@ func NewNamespace(prefix string, params ...innnerNamespace) *Namespace {
return ns
}
// set condtion function
// Cond set condtion function
// if cond return true can run this namespace, else can't
// usage:
// ns.Cond(func (ctx *context.Context) bool{
@ -72,7 +73,7 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace {
return n
}
// add filter in the Namespace
// Filter add filter in the Namespace
// action has before & after
// FilterFunc
// usage:
@ -95,98 +96,98 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
return n
}
// same as beego.Rourer
// Router same as beego.Rourer
// refer: https://godoc.org/github.com/astaxie/beego#Router
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
n.handlers.Add(rootpath, c, mappingMethods...)
return n
}
// same as beego.AutoRouter
// AutoRouter same as beego.AutoRouter
// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter
func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace {
n.handlers.AddAuto(c)
return n
}
// same as beego.AutoPrefix
// AutoPrefix same as beego.AutoPrefix
// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix
func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace {
n.handlers.AddAutoPrefix(prefix, c)
return n
}
// same as beego.Get
// Get same as beego.Get
// refer: https://godoc.org/github.com/astaxie/beego#Get
func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace {
n.handlers.Get(rootpath, f)
return n
}
// same as beego.Post
// Post same as beego.Post
// refer: https://godoc.org/github.com/astaxie/beego#Post
func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace {
n.handlers.Post(rootpath, f)
return n
}
// same as beego.Delete
// Delete same as beego.Delete
// refer: https://godoc.org/github.com/astaxie/beego#Delete
func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace {
n.handlers.Delete(rootpath, f)
return n
}
// same as beego.Put
// Put same as beego.Put
// refer: https://godoc.org/github.com/astaxie/beego#Put
func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace {
n.handlers.Put(rootpath, f)
return n
}
// same as beego.Head
// Head same as beego.Head
// refer: https://godoc.org/github.com/astaxie/beego#Head
func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace {
n.handlers.Head(rootpath, f)
return n
}
// same as beego.Options
// Options same as beego.Options
// refer: https://godoc.org/github.com/astaxie/beego#Options
func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace {
n.handlers.Options(rootpath, f)
return n
}
// same as beego.Patch
// Patch same as beego.Patch
// refer: https://godoc.org/github.com/astaxie/beego#Patch
func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace {
n.handlers.Patch(rootpath, f)
return n
}
// same as beego.Any
// Any same as beego.Any
// refer: https://godoc.org/github.com/astaxie/beego#Any
func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace {
n.handlers.Any(rootpath, f)
return n
}
// same as beego.Handler
// Handler same as beego.Handler
// refer: https://godoc.org/github.com/astaxie/beego#Handler
func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace {
n.handlers.Handler(rootpath, h)
return n
}
// add include class
// Include add include class
// refer: https://godoc.org/github.com/astaxie/beego#Include
func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
n.handlers.Include(cList...)
return n
}
// nest Namespace
// Namespace add nest Namespace
// usage:
//ns := beego.NewNamespace(“/v1”).
//Namespace(
@ -230,7 +231,7 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
return n
}
// register Namespace into beego.Handler
// AddNamespace register Namespace into beego.Handler
// support multi Namespace
func AddNamespace(nl ...*Namespace) {
for _, n := range nl {
@ -275,113 +276,113 @@ func addPrefix(t *Tree, prefix string) {
}
// Namespace Condition
func NSCond(cond namespaceCond) innnerNamespace {
// NSCond is Namespace Condition
func NSCond(cond namespaceCond) LinkNamespace {
return func(ns *Namespace) {
ns.Cond(cond)
}
}
// Namespace BeforeRouter filter
func NSBefore(filiterList ...FilterFunc) innnerNamespace {
// NSBefore Namespace BeforeRouter filter
func NSBefore(filiterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Filter("before", filiterList...)
}
}
// Namespace FinishRouter filter
func NSAfter(filiterList ...FilterFunc) innnerNamespace {
// NSAfter add Namespace FinishRouter filter
func NSAfter(filiterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Filter("after", filiterList...)
}
}
// Namespace Include ControllerInterface
func NSInclude(cList ...ControllerInterface) innnerNamespace {
// NSInclude Namespace Include ControllerInterface
func NSInclude(cList ...ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.Include(cList...)
}
}
// Namespace Router
func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) innnerNamespace {
// NSRouter call Namespace Router
func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace {
return func(ns *Namespace) {
ns.Router(rootpath, c, mappingMethods...)
}
}
// Namespace Get
func NSGet(rootpath string, f FilterFunc) innnerNamespace {
// NSGet call Namespace Get
func NSGet(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Get(rootpath, f)
}
}
// Namespace Post
func NSPost(rootpath string, f FilterFunc) innnerNamespace {
// NSPost call Namespace Post
func NSPost(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Post(rootpath, f)
}
}
// Namespace Head
func NSHead(rootpath string, f FilterFunc) innnerNamespace {
// NSHead call Namespace Head
func NSHead(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Head(rootpath, f)
}
}
// Namespace Put
func NSPut(rootpath string, f FilterFunc) innnerNamespace {
// NSPut call Namespace Put
func NSPut(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Put(rootpath, f)
}
}
// Namespace Delete
func NSDelete(rootpath string, f FilterFunc) innnerNamespace {
// NSDelete call Namespace Delete
func NSDelete(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Delete(rootpath, f)
}
}
// Namespace Any
func NSAny(rootpath string, f FilterFunc) innnerNamespace {
// NSAny call Namespace Any
func NSAny(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Any(rootpath, f)
}
}
// Namespace Options
func NSOptions(rootpath string, f FilterFunc) innnerNamespace {
// NSOptions call Namespace Options
func NSOptions(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Options(rootpath, f)
}
}
// Namespace Patch
func NSPatch(rootpath string, f FilterFunc) innnerNamespace {
// NSPatch call Namespace Patch
func NSPatch(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Patch(rootpath, f)
}
}
//Namespace AutoRouter
func NSAutoRouter(c ControllerInterface) innnerNamespace {
// NSAutoRouter call Namespace AutoRouter
func NSAutoRouter(c ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.AutoRouter(c)
}
}
// Namespace AutoPrefix
func NSAutoPrefix(prefix string, c ControllerInterface) innnerNamespace {
// NSAutoPrefix call Namespace AutoPrefix
func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.AutoPrefix(prefix, c)
}
}
// Namespace add sub Namespace
func NSNamespace(prefix string, params ...innnerNamespace) innnerNamespace {
// NSNamespace add sub Namespace
func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace {
return func(ns *Namespace) {
n := NewNamespace(prefix, params...)
ns.Namespace(n)

View File

@ -46,7 +46,7 @@ func printHelp(errs ...string) {
os.Exit(2)
}
// listen for orm command and then run it if command arguments passed.
// RunCommand listen for orm command and then run it if command arguments passed.
func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
@ -100,7 +100,7 @@ func (d *commandSyncDb) Parse(args []string) {
func (d *commandSyncDb) Run() error {
var drops []string
if d.force {
drops = getDbDropSql(d.al)
drops = getDbDropSQL(d.al)
}
db := d.al.DB
@ -124,7 +124,7 @@ func (d *commandSyncDb) Run() error {
}
}
sqls, indexes := getDbCreateSql(d.al)
sqls, indexes := getDbCreateSQL(d.al)
tables, err := d.al.DbBaser.GetTables(db)
if err != nil {
@ -180,7 +180,7 @@ func (d *commandSyncDb) Run() error {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}
query := idx.Sql
query := idx.SQL
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
@ -203,7 +203,7 @@ func (d *commandSyncDb) Run() error {
queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql)
queries = append(queries, idx.SQL)
}
for _, query := range queries {
@ -228,12 +228,12 @@ func (d *commandSyncDb) Run() error {
}
// database creation commander interface implement.
type commandSqlAll struct {
type commandSQLAll struct {
al *alias
}
// parse orm command line arguments.
func (d *commandSqlAll) Parse(args []string) {
func (d *commandSQLAll) Parse(args []string) {
var name string
flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError)
@ -244,13 +244,13 @@ func (d *commandSqlAll) Parse(args []string) {
}
// run orm line command.
func (d *commandSqlAll) Run() error {
sqls, indexes := getDbCreateSql(d.al)
func (d *commandSQLAll) Run() error {
sqls, indexes := getDbCreateSQL(d.al)
var all []string
for i, mi := range modelCache.allOrdered() {
queries := []string{sqls[i]}
for _, idx := range indexes[mi.table] {
queries = append(queries, idx.Sql)
queries = append(queries, idx.SQL)
}
sql := strings.Join(queries, "\n")
all = append(all, sql)
@ -262,10 +262,10 @@ func (d *commandSqlAll) Run() error {
func init() {
commands["syncdb"] = new(commandSyncDb)
commands["sqlall"] = new(commandSqlAll)
commands["sqlall"] = new(commandSQLAll)
}
// run syncdb command line.
// RunSyncdb run syncdb command line.
// name means table's alias name. default is "default".
// force means run next sql if the current is error.
// verbose means show all info when running command or not.

View File

@ -23,11 +23,11 @@ import (
type dbIndex struct {
Table string
Name string
Sql string
SQL string
}
// create database drop sql.
func getDbDropSql(al *alias) (sqls []string) {
func getDbDropSQL(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
@ -45,13 +45,14 @@ func getDbDropSql(al *alias) (sqls []string) {
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes()
fieldType := fi.fieldType
fieldSize := fi.size
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeCharField:
col = fmt.Sprintf(T["string"], fi.size)
col = fmt.Sprintf(T["string"], fieldSize)
case TypeTextField:
col = T["string-text"]
case TypeDateField:
@ -65,7 +66,7 @@ checkColumn:
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
if al.Driver == DR_Sqlite {
if al.Driver == DRSqlite {
fieldType = TypeIntegerField
goto checkColumn
}
@ -89,6 +90,7 @@ checkColumn:
}
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
fieldSize = fi.relModelInfo.fields.pk.size
goto checkColumn
}
@ -112,7 +114,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
}
// create database creation string.
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
@ -142,7 +144,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
if fi.auto {
switch al.Driver {
case DR_Sqlite, DR_Postgres:
case DRSqlite, DRPostgres:
column += T["auto"]
default:
column += col + " " + T["auto"]
@ -201,7 +203,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
sql += strings.Join(columns, ",\n")
sql += "\n)"
if al.Driver == DR_MySQL {
if al.Driver == DRMySQL {
var engine string
if mi.model != nil {
engine = getTableEngine(mi.addrField)
@ -237,7 +239,7 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
index := dbIndex{}
index.Table = mi.table
index.Name = name
index.Sql = sql
index.SQL = sql
tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
}
@ -247,7 +249,6 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
return
}
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
func getColumnDefault(fi *fieldInfo) string {
var (
@ -263,14 +264,18 @@ func getColumnDefault(fi *fieldInfo) string {
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
return v;
case TypeDateField, TypeDateTimeField, TypeTextField:
return v
case TypeBooleanField, TypeBitField, TypeSmallIntegerField, TypeIntegerField,
case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField,
TypeDecimalField:
t = " DEFAULT %s "
d = "0"
case TypeBooleanField:
t = " DEFAULT %s "
d = "FALSE"
}
if fi.colDefault {

173
orm/db.go
View File

@ -24,12 +24,13 @@ import (
)
const (
format_Date = "2006-01-02"
format_DateTime = "2006-01-02 15:04:05"
formatDate = "2006-01-02"
formatDateTime = "2006-01-02 15:04:05"
)
var (
ErrMissPK = errors.New("missed pk value") // missing pk error
// ErrMissPK missing pk error
ErrMissPK = errors.New("missed pk value")
)
var (
@ -44,6 +45,8 @@ var (
"gte": true,
"lt": true,
"lte": true,
"eq": true,
"nq": true,
"startswith": true,
"endswith": true,
"istartswith": true,
@ -214,14 +217,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
}
}
if fi.null == false && value == nil {
return nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
}
}
}
}
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert {
if fi.autoNow || fi.autoNowAdd && insert {
if insert {
if t, ok := value.(time.Time); ok && !t.IsZero() {
break
@ -280,13 +283,12 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
var id int64
err := row.Scan(&id)
return id, err
} else {
if res, err := stmt.Exec(values...); err == nil {
}
res, err := stmt.Exec(values...)
if err == nil {
return res.LastInsertId()
} else {
}
return 0, err
}
}
}
// query sql ,read records and persist in dbBaser.
@ -324,7 +326,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q)
refs := make([]interface{}, colsNum)
for i, _ := range refs {
for i := range refs {
var ref interface{}
refs[i] = &ref
}
@ -337,15 +339,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
return ErrNoRows
}
return err
} else {
}
elm := reflect.New(mi.addrField.Elem().Type())
mind := reflect.Indirect(elm)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz)
ind.Set(mind)
}
return nil
}
@ -423,7 +421,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
Q := d.ins.TableQuote()
marks := make([]string, len(names))
for i, _ := range marks {
for i := range marks {
marks[i] = "?"
}
@ -442,20 +440,19 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
if res, err := q.Exec(query, values...); err == nil {
res, err := q.Exec(query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId()
} else {
}
return 0, err
}
} else {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
}
}
// execute update sql dbQuerier with given struct reflect.Value.
@ -491,11 +488,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
d.ins.ReplaceMarks(&query)
if res, err := q.Exec(query, setValues...); err == nil {
res, err := q.Exec(query, setValues...)
if err == nil {
return res.RowsAffected()
} else {
return 0, err
}
return 0, err
}
// execute delete sql dbQuerier with given struct reflect.Value.
@ -511,14 +508,12 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q)
d.ins.ReplaceMarks(&query)
if res, err := q.Exec(query, pkValue); err == nil {
res, err := q.Exec(query, pkValue)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if num > 0 {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
@ -527,17 +522,14 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
err := d.deleteRels(q, mi, []interface{}{pkValue}, tz)
if err != nil {
return num, err
}
}
return num, err
} else {
return 0, err
}
return 0, err
}
// update table-related record by querySet.
@ -563,11 +555,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
tables.parseRelated(qs.related, qs.relDepth)
}
where, args := tables.getCondSql(cond, false, tz)
where, args := tables.getCondSQL(cond, false, tz)
values = append(values, args...)
join := tables.getJoinSql()
join := tables.getJoinSQL()
var query, T string
@ -583,13 +575,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q)
if c, ok := values[i].(colValue); ok {
switch c.opt {
case Col_Add:
case ColAdd:
cols = append(cols, col+" = "+col+" + ?")
case Col_Minus:
case ColMinus:
cols = append(cols, col+" = "+col+" - ?")
case Col_Multiply:
case ColMultiply:
cols = append(cols, col+" = "+col+" * ?")
case Col_Except:
case ColExcept:
cols = append(cols, col+" = "+col+" / ?")
}
values[i] = c.value
@ -608,12 +600,11 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
d.ins.ReplaceMarks(&query)
if res, err := q.Exec(query, values...); err == nil {
res, err := q.Exec(query, values...)
if err == nil {
return res.RowsAffected()
} else {
return 0, err
}
return 0, err
}
// delete related records.
@ -622,23 +613,23 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
switch fi.onDelete {
case od_CASCADE:
case odCascade:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz)
if err != nil {
return err
}
case od_SET_DEFAULT, od_SET_NULL:
case odSetDefault, odSetNULL:
cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
params := Params{fi.column: nil}
if fi.onDelete == od_SET_DEFAULT {
if fi.onDelete == odSetDefault {
params[fi.column] = fi.initial.String()
}
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz)
if err != nil {
return err
}
case od_DO_NOTHING:
case odDoNothing:
}
}
return nil
@ -659,8 +650,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
Q := d.ins.TableQuote()
where, args := tables.getCondSql(cond, false, tz)
join := tables.getJoinSql()
where, args := tables.getCondSQL(cond, false, tz)
join := tables.getJoinSQL()
cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where)
@ -668,16 +659,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
r, err := q.Query(query, args...)
if err != nil {
return 0, err
} else {
rs = r
}
rs = r
defer rs.Close()
var ref interface{}
args = make([]interface{}, 0)
cnt := 0
for rs.Next() {
@ -693,31 +682,28 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
}
marks := make([]string, len(args))
for i, _ := range marks {
for i := range marks {
marks[i] = "?"
}
sql := fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
d.ins.ReplaceMarks(&query)
if res, err := q.Exec(query, args...); err == nil {
res, err := q.Exec(query, args...)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if num > 0 {
err := d.deleteRels(q, mi, args, tz)
if err != nil {
return num, err
}
}
return num, nil
} else {
return 0, err
}
return 0, err
}
// read related records.
@ -799,10 +785,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false, tz)
orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(mi, offset, rlimit)
join := tables.getJoinSql()
where, args := tables.getCondSQL(cond, false, tz)
groupBy := tables.getGroupSQL(qs.groups)
orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, offset, rlimit)
join := tables.getJoinSQL()
for _, tbl := range tables.tables {
if tbl.sel {
@ -812,19 +799,23 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
}
}
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
sqlSelect := "SELECT"
if qs.distinct {
sqlSelect += " DISTINCT"
}
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
r, err := q.Query(query, args...)
if err != nil {
return 0, err
} else {
rs = r
}
rs = r
refs := make([]interface{}, colsNum)
for i, _ := range refs {
for i := range refs {
var ref interface{}
refs[i] = &ref
}
@ -935,9 +926,9 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false, tz)
tables.getOrderSql(qs.orders)
join := tables.getJoinSql()
where, args := tables.getCondSQL(cond, false, tz)
tables.getOrderSQL(qs.orders)
join := tables.getJoinSQL()
Q := d.ins.TableQuote()
@ -952,7 +943,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
}
// generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := ""
params := getFlatParams(fi, args, tz)
@ -964,7 +955,7 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
switch operator {
case "in":
marks := make([]string, len(params))
for i, _ := range marks {
for i := range marks {
marks[i] = "?"
}
sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
@ -977,7 +968,7 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
if len(params) > 1 {
panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params)))
}
sql = d.ins.OperatorSql(operator)
sql = d.ins.OperatorSQL(operator)
switch operator {
case "exact":
if arg == nil {
@ -1105,12 +1096,12 @@ setValue:
)
if len(s) >= 19 {
s = s[:19]
t, err = time.ParseInLocation(format_DateTime, s, tz)
t, err = time.ParseInLocation(formatDateTime, s, tz)
} else {
if len(s) > 10 {
s = s[:10]
}
t, err = time.ParseInLocation(format_Date, s, tz)
t, err = time.ParseInLocation(formatDate, s, tz)
}
t = t.In(DefaultTimeLoc)
@ -1441,26 +1432,24 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
}
}
where, args := tables.getCondSql(cond, false, tz)
orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(mi, qs.offset, qs.limit)
join := tables.getJoinSql()
where, args := tables.getCondSQL(cond, false, tz)
groupBy := tables.getGroupSQL(qs.groups)
orderBy := tables.getOrderSQL(qs.orders)
limit := tables.getLimitSQL(mi, qs.offset, qs.limit)
join := tables.getJoinSQL()
sels := strings.Join(cols, ", ")
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s%s", sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
d.ins.ReplaceMarks(&query)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
rs, err := q.Query(query, args...)
if err != nil {
return 0, err
} else {
rs = r
}
refs := make([]interface{}, len(cols))
for i, _ := range refs {
for i := range refs {
var ref interface{}
refs[i] = &ref
}
@ -1473,11 +1462,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
)
for rs.Next() {
if cnt == 0 {
if cols, err := rs.Columns(); err != nil {
cols, err := rs.Columns()
if err != nil {
return 0, err
} else {
columns = cols
}
columns = cols
}
if err := rs.Scan(refs...); err != nil {
@ -1641,7 +1630,7 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
}
// not implement.
func (d *dbBase) OperatorSql(operator string) string {
func (d *dbBase) OperatorSQL(operator string) string {
panic(ErrNotImplement)
}

View File

@ -22,15 +22,17 @@ import (
"time"
)
// database driver constant int.
// DriverType database driver constant int.
type DriverType int
// Enum the Database driver
const (
_ DriverType = iota // int enum type
DR_MySQL // mysql
DR_Sqlite // sqlite
DR_Oracle // oracle
DR_Postgres // pgsql
DRMySQL // mysql
DRSqlite // sqlite
DROracle // oracle
DRPostgres // pgsql
DRTiDB // TiDB
)
// database driver string.
@ -53,15 +55,17 @@ var _ Driver = new(driver)
var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]DriverType{
"mysql": DR_MySQL,
"postgres": DR_Postgres,
"sqlite3": DR_Sqlite,
"mysql": DRMySQL,
"postgres": DRPostgres,
"sqlite3": DRSqlite,
"tidb": DRTiDB,
}
dbBasers = map[DriverType]dbBaser{
DR_MySQL: newdbBaseMysql(),
DR_Sqlite: newdbBaseSqlite(),
DR_Oracle: newdbBaseMysql(),
DR_Postgres: newdbBasePostgres(),
DRMySQL: newdbBaseMysql(),
DRSqlite: newdbBaseSqlite(),
DROracle: newdbBaseOracle(),
DRPostgres: newdbBasePostgres(),
DRTiDB: newdbBaseTidb(),
}
)
@ -119,7 +123,7 @@ func detectTZ(al *alias) {
}
switch al.Driver {
case DR_MySQL:
case DRMySQL:
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string
row.Scan(&tz)
@ -147,10 +151,10 @@ func detectTZ(al *alias) {
al.Engine = "INNODB"
}
case DR_Sqlite:
case DRSqlite:
al.TZ = time.UTC
case DR_Postgres:
case DRPostgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string
row.Scan(&tz)
@ -188,12 +192,13 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
return al, nil
}
// AddAliasWthDB add a aliasName for the drivename
func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error {
_, err := addAliasWthDB(aliasName, driverName, db)
return err
}
// Setting the database connect params. Use the database driver self dataSource args.
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error {
var (
err error
@ -236,7 +241,7 @@ end:
return err
}
// Register a database driver use specify driver name, this can be definition the driver is which database type.
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; ok == false {
drivers[driverName] = typ
@ -248,7 +253,7 @@ func RegisterDriver(driverName string, typ DriverType) error {
return nil
}
// Change the database default used timezone
// SetDataBaseTZ Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz
@ -258,14 +263,14 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
return nil
}
// Change the max idle conns for *sql.DB, use specify database alias name
// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName)
al.MaxIdleConns = maxIdleConns
al.DB.SetMaxIdleConns(maxIdleConns)
}
// Change the max open conns for *sql.DB, use specify database alias name
// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
func SetMaxOpenConns(aliasName string, maxOpenConns int) {
al := getDbAlias(aliasName)
al.MaxOpenConns = maxOpenConns
@ -275,7 +280,7 @@ func SetMaxOpenConns(aliasName string, maxOpenConns int) {
}
}
// Get *sql.DB from registered database by db alias name.
// GetDB Get *sql.DB from registered database by db alias name.
// Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string
@ -284,9 +289,9 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
} else {
name = "default"
}
if al, ok := dataBaseCache.get(name); ok {
al, ok := dataBaseCache.get(name)
if ok {
return al.DB, nil
} else {
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
}
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
}

View File

@ -30,6 +30,8 @@ var mysqlOperators = map[string]string{
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"eq": "= ?",
"ne": "!= ?",
"startswith": "LIKE BINARY ?",
"endswith": "LIKE BINARY ?",
"istartswith": "LIKE ?",
@ -65,7 +67,7 @@ type dbBaseMysql struct {
var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
func (d *dbBaseMysql) OperatorSql(operator string) string {
func (d *dbBaseMysql) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}

View File

@ -29,6 +29,8 @@ var postgresOperators = map[string]string{
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"eq": "= ?",
"ne": "!= ?",
"startswith": "LIKE ?",
"endswith": "LIKE ?",
"istartswith": "LIKE UPPER(?)",
@ -64,7 +66,7 @@ type dbBasePostgres struct {
var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
func (d *dbBasePostgres) OperatorSql(operator string) string {
func (d *dbBasePostgres) OperatorSQL(operator string) string {
return postgresOperators[operator]
}
@ -99,7 +101,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
num := 0
for _, c := range q {
if c == '?' {
num += 1
num++
}
}
if num == 0 {
@ -112,7 +114,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
if c == '?' {
data = append(data, '$')
data = append(data, []byte(strconv.Itoa(num))...)
num += 1
num++
} else {
data = append(data, c)
}

View File

@ -29,6 +29,8 @@ var sqliteOperators = map[string]string{
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"eq": "= ?",
"ne": "!= ?",
"startswith": "LIKE ? ESCAPE '\\'",
"endswith": "LIKE ? ESCAPE '\\'",
"istartswith": "LIKE ? ESCAPE '\\'",
@ -64,7 +66,7 @@ type dbBaseSqlite struct {
var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator.
func (d *dbBaseSqlite) OperatorSql(operator string) string {
func (d *dbBaseSqlite) OperatorSQL(operator string) string {
return sqliteOperators[operator]
}

View File

@ -164,7 +164,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
}
// generate join string.
func (t *dbTables) getJoinSql() (join string) {
func (t *dbTables) getJoinSQL() (join string) {
Q := t.base.TableQuote()
for _, jt := range t.tables {
@ -186,7 +186,7 @@ func (t *dbTables) getJoinSql() (join string) {
table = jt.mi.table
switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
@ -220,7 +220,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
)
num := len(exprs) - 1
names := make([]string, 0)
var names []string
inner := true
@ -326,7 +326,7 @@ loopFor:
}
// generate condition sql.
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
@ -347,7 +347,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
where += "NOT "
}
if p.isCond {
w, ps := t.getCondSql(p.cond, true, tz)
w, ps := t.getCondSQL(p.cond, true, tz)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
@ -372,12 +372,12 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
operator = "exact"
}
operSql, args := t.base.GenerateOperatorSql(mi, fi, operator, p.args, tz)
operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSql)
where += fmt.Sprintf("%s %s ", leftCol, operSQL)
params = append(params, args...)
}
@ -390,8 +390,32 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
return
}
// generate group sql.
func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
if len(groups) == 0 {
return
}
Q := t.base.TableQuote()
groupSqls := make([]string, 0, len(groups))
for _, group := range groups {
exprs := strings.Split(group, ExprSep)
index, _, fi, suc := t.parseExprs(t.mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
}
groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
return
}
// generate order sql.
func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
if len(orders) == 0 {
return
}
@ -415,12 +439,12 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
}
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return
}
// generate limit sql.
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 {
limit = int64(DefaultRowsLimit)
}

63
orm/db_tidb.go Normal file
View File

@ -0,0 +1,63 @@
// Copyright 2015 TiDB Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
)
// mysql dbBaser implementation.
type dbBaseTidb struct {
dbBase
}
var _ dbBaser = new(dbBaseTidb)
// get mysql operator.
func (d *dbBaseTidb) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}
// get mysql table field types.
func (d *dbBaseTidb) DbTypes() map[string]string {
return mysqlTypes
}
// show table sql for mysql.
func (d *dbBaseTidb) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
// show columns sql of table for mysql.
func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
// execute sql to check index exist.
func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
// create new mysql dbBaser.
func newdbBaseTidb() dbBaser {
b := new(dbBaseTidb)
b.ins = b
return b
}

View File

@ -24,9 +24,8 @@ import (
func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok {
return al
} else {
panic(fmt.Errorf("unknown DataBase alias name %s", name))
}
panic(fmt.Errorf("unknown DataBase alias name %s", name))
}
// get pk column info.
@ -80,19 +79,19 @@ outFor:
var err error
if len(v) >= 19 {
s := v[:19]
t, err = time.ParseInLocation(format_DateTime, s, DefaultTimeLoc)
t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
} else {
s := v
if len(v) > 10 {
s = v[:10]
}
t, err = time.ParseInLocation(format_Date, s, tz)
t, err = time.ParseInLocation(formatDate, s, tz)
}
if err == nil {
if fi.fieldType == TypeDateField {
v = t.In(tz).Format(format_Date)
v = t.In(tz).Format(formatDate)
} else {
v = t.In(tz).Format(format_DateTime)
v = t.In(tz).Format(formatDateTime)
}
}
}
@ -137,9 +136,9 @@ outFor:
case reflect.Struct:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(format_Date)
arg = v.In(tz).Format(formatDate)
} else {
arg = v.In(tz).Format(format_DateTime)
arg = v.In(tz).Format(formatDateTime)
}
} else {
typ := val.Type()

View File

@ -19,10 +19,10 @@ import (
)
const (
od_CASCADE = "cascade"
od_SET_NULL = "set_null"
od_SET_DEFAULT = "set_default"
od_DO_NOTHING = "do_nothing"
odCascade = "cascade"
odSetNULL = "set_null"
odSetDefault = "set_default"
odDoNothing = "do_nothing"
defaultStructTagName = "orm"
defaultStructTagDelim = ";"
)
@ -113,7 +113,7 @@ func (mc *_modelCache) clean() {
mc.done = false
}
// Clean model cache. Then you can re-RegisterModel.
// ResetModelCache Clean model cache. Then you can re-RegisterModel.
// Common use this api for test case.
func ResetModelCache() {
modelCache.clean()

View File

@ -51,12 +51,10 @@ func registerModel(prefix string, model interface{}) {
}
info := newModelInfo(val)
if info.fields.pk == nil {
outFor:
for _, fi := range info.fields.fieldsDB {
if fi.name == "Id" {
if fi.sf.Tag.Get(defaultStructTagName) == "" {
if strings.ToLower(fi.name) == "id" {
switch fi.addrValue.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
fi.auto = true
@ -66,7 +64,6 @@ func registerModel(prefix string, model interface{}) {
}
}
}
}
if info.fields.pk == nil {
fmt.Printf("<orm.RegisterModel> `%s` need a primary key field\n", name)
@ -269,7 +266,10 @@ func bootStrap() {
if found == false {
mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
if ffi.relModelInfo == mi {
conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
fi.relTable != "" && fi.relTable == ffi.relTable ||
fi.relThrough == "" && fi.relTable == ""
if ffi.relModelInfo == mi && conditions {
found = true
fi.reverseField = ffi.reverseFieldInfoTwo.name
@ -298,12 +298,12 @@ end:
}
}
// register models
// RegisterModel register models
func RegisterModel(models ...interface{}) {
RegisterModelWithPrefix("", models...)
}
// register models with a prefix
// RegisterModelWithPrefix register models with a prefix
func RegisterModelWithPrefix(prefix string, models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
@ -314,7 +314,7 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
}
}
// bootrap models.
// BootStrap bootrap models.
// make all model parsed and can not add more models
func BootStrap() {
if modelCache.done {

View File

@ -15,49 +15,28 @@
package orm
import (
"errors"
"fmt"
"strconv"
"time"
)
// Define the Type enum
const (
// bool
TypeBooleanField = 1 << iota
// string
TypeCharField
// string
TypeTextField
// time.Time
TypeDateField
// time.Time
TypeDateTimeField
// int8
TypeBitField
// int16
TypeSmallIntegerField
// int32
TypeIntegerField
// int64
TypeBigIntegerField
// uint8
TypePositiveBitField
// uint16
TypePositiveSmallIntegerField
// uint32
TypePositiveIntegerField
// uint64
TypePositiveBigIntegerField
// float64
TypeFloatField
// float64
TypeDecimalField
RelForeignKey
RelOneToOne
RelManyToMany
@ -65,6 +44,7 @@ const (
RelReverseMany
)
// Define some logic enum
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 8 << 9
@ -72,25 +52,30 @@ const (
IsFieldType = ^-RelReverseMany<<1 + 1
)
// A true/false field.
// BooleanField A true/false field.
type BooleanField bool
// Value return the BooleanField
func (e BooleanField) Value() bool {
return bool(e)
}
// Set will set the BooleanField
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
// String format the Bool to string
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
// FieldType return BooleanField the type
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
// SetRaw set the interface to bool
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
@ -102,56 +87,65 @@ func (e *BooleanField) SetRaw(value interface{}) error {
}
return err
default:
return errors.New(fmt.Sprintf("<BooleanField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the current value
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
// verify the BooleanField implement the Fielder interface
var _ Fielder = new(BooleanField)
// A string field
// CharField A string field
// required values tag: size
// The size is enforced at the database level and in modelss validation.
// eg: `orm:"size(120)"`
type CharField string
// Value return the CharField's Value
func (e CharField) Value() string {
return string(e)
}
// Set CharField value
func (e *CharField) Set(d string) {
*e = CharField(d)
}
// String return the CharField
func (e *CharField) String() string {
return e.Value()
}
// FieldType return the enum type
func (e *CharField) FieldType() int {
return TypeCharField
}
// SetRaw set the interface to string
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return errors.New(fmt.Sprintf("<CharField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the CharField value
func (e *CharField) RawValue() interface{} {
return e.Value()
}
// verify CharField implement Fielder
var _ Fielder = new(CharField)
// A date, represented in go by a time.Time instance.
// DateField A date, represented in go by a time.Time instance.
// only date values like 2006-01-02
// Has a few extra, optional attr tag:
//
@ -166,106 +160,125 @@ var _ Fielder = new(CharField)
// eg: `orm:"auto_now"` or `orm:"auto_now_add"`
type DateField time.Time
// Value return the time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
// Set set the DateField's value
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
// String convert datatime to string
func (e *DateField) String() string {
return e.Value().String()
}
// FieldType return enum type Date
func (e *DateField) FieldType() int {
return TypeDateField
}
// SetRaw convert the interface to time.Time. Allow string and time.Time
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, format_Date)
v, err := timeParse(d, formatDate)
if err != nil {
e.Set(v)
}
return err
default:
return errors.New(fmt.Sprintf("<DateField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return Date value
func (e *DateField) RawValue() interface{} {
return e.Value()
}
// verify DateField implement fielder interface
var _ Fielder = new(DateField)
// A date, represented in go by a time.Time instance.
// DateTimeField A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
// Value return the datatime value
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
// Set set the time.Time to datatime
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
// String return the time's String
func (e *DateTimeField) String() string {
return e.Value().String()
}
// FieldType return the enum TypeDateTimeField
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
// SetRaw convert the string or time.Time to DateTimeField
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, format_DateTime)
v, err := timeParse(d, formatDateTime)
if err != nil {
e.Set(v)
}
return err
default:
return errors.New(fmt.Sprintf("<DateTimeField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the datatime value
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
// verify datatime implement fielder
var _ Fielder = new(DateTimeField)
// A floating-point number represented in go by a float32 value.
// FloatField A floating-point number represented in go by a float32 value.
type FloatField float64
// Value return the FloatField value
func (e FloatField) Value() float64 {
return float64(e)
}
// Set the Float64
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
// String return the string
func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32)
}
// FieldType return the enum type
func (e *FloatField) FieldType() int {
return TypeFloatField
}
// SetRaw converter interface Float64 float32 or string to FloatField
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
@ -278,36 +291,43 @@ func (e *FloatField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<FloatField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the FloatField value
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
// verify FloatField implement Fielder
var _ Fielder = new(FloatField)
// -32768 to 32767
// SmallIntegerField -32768 to 32767
type SmallIntegerField int16
// Value return int16 value
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
// Set the SmallIntegerField value
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
// String convert smallint to string
func (e *SmallIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type SmallIntegerField
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
// SetRaw convert interface int16/string to int16
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
@ -318,36 +338,43 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<SmallIntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return smallint value
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify SmallIntegerField implement Fielder
var _ Fielder = new(SmallIntegerField)
// -2147483648 to 2147483647
// IntegerField -2147483648 to 2147483647
type IntegerField int32
// Value return the int32
func (e IntegerField) Value() int32 {
return int32(e)
}
// Set IntegerField value
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
// String convert Int32 to string
func (e *IntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return the enum type
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
// SetRaw convert interface int32/string to int32
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
@ -358,36 +385,43 @@ func (e *IntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<IntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return IntegerField value
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
// verify IntegerField implement Fielder
var _ Fielder = new(IntegerField)
// -9223372036854775808 to 9223372036854775807.
// BigIntegerField -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64
// Value return int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
// Set the BigIntegerField value
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
// String convert BigIntegerField to string
func (e *BigIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
// SetRaw convert interface int64/string to int64
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
@ -398,36 +432,43 @@ func (e *BigIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<BigIntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return BigIntegerField value
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify BigIntegerField implement Fielder
var _ Fielder = new(BigIntegerField)
// 0 to 65535
// PositiveSmallIntegerField 0 to 65535
type PositiveSmallIntegerField uint16
// Value return uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
// Set PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
// String convert uint16 to string
func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
// SetRaw convert Interface uint16/string to uint16
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
@ -438,36 +479,43 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue returns PositiveSmallIntegerField value
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveSmallIntegerField implement Fielder
var _ Fielder = new(PositiveSmallIntegerField)
// 0 to 4294967295
// PositiveIntegerField 0 to 4294967295
type PositiveIntegerField uint32
// Value return PositiveIntegerField value. Uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
// Set the PositiveIntegerField value
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
// String convert PositiveIntegerField to string
func (e *PositiveIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint32/string to Uint32
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
@ -478,36 +526,43 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveIntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return the PositiveIntegerField Value
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveIntegerField implement Fielder
var _ Fielder = new(PositiveIntegerField)
// 0 to 18446744073709551615
// PositiveBigIntegerField 0 to 18446744073709551615
type PositiveBigIntegerField uint64
// Value return uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
// Set PositiveBigIntegerField value
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
// String convert PositiveBigIntegerField to string
func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value())
}
// FieldType return enum type
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
// SetRaw convert interface uint64/string to Uint64
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
@ -518,48 +573,57 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return PositiveBigIntegerField value
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
// verify PositiveBigIntegerField implement Fielder
var _ Fielder = new(PositiveBigIntegerField)
// A large text field.
// TextField A large text field.
type TextField string
// Value return TextField value
func (e TextField) Value() string {
return string(e)
}
// Set the TextField value
func (e *TextField) Set(d string) {
*e = TextField(d)
}
// String convert TextField to string
func (e *TextField) String() string {
return e.Value()
}
// FieldType return enum type
func (e *TextField) FieldType() int {
return TypeTextField
}
// SetRaw convert interface string to string
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return errors.New(fmt.Sprintf("<TextField.SetRaw> unknown value `%s`", value))
return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value)
}
return nil
}
// RawValue return TextField value
func (e *TextField) RawValue() interface{} {
return e.Value()
}
// verify TextField implement Fielder
var _ Fielder = new(TextField)

View File

@ -119,8 +119,8 @@ type fieldInfo struct {
colDefault bool
initial StrTo
size int
auto_now bool
auto_now_add bool
autoNow bool
autoNowAdd bool
rel bool
reverse bool
reverseField string
@ -223,6 +223,11 @@ checkType:
break checkType
case "many":
fieldType = RelReverseMany
if tv := tags["rel_table"]; tv != "" {
fi.relTable = tv
} else if tv := tags["rel_through"]; tv != "" {
fi.relThrough = tv
}
break checkType
default:
err = fmt.Errorf("error")
@ -309,20 +314,20 @@ checkType:
if fi.rel && fi.dbcol {
switch onDelete {
case od_CASCADE, od_DO_NOTHING:
case od_SET_DEFAULT:
case odCascade, odDoNothing:
case odSetDefault:
if initial.Exist() == false {
err = errors.New("on_delete: set_default need set field a default value")
goto end
}
case od_SET_NULL:
case odSetNULL:
if fi.null == false {
err = errors.New("on_delete: set_null need set field null")
goto end
}
default:
if onDelete == "" {
onDelete = od_CASCADE
onDelete = odCascade
} else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end
@ -350,9 +355,9 @@ checkType:
fi.unique = false
case TypeDateField, TypeDateTimeField:
if attrs["auto_now"] {
fi.auto_now = true
fi.autoNow = true
} else if attrs["auto_now_add"] {
fi.auto_now_add = true
fi.autoNowAdd = true
}
case TypeFloatField:
case TypeDecimalField:

View File

@ -15,7 +15,6 @@
package orm
import (
"errors"
"fmt"
"os"
"reflect"
@ -72,13 +71,13 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
added := info.fields.Add(fi)
if added == false {
err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
err = fmt.Errorf("duplicate column name: %s", fi.column)
break
}
if fi.pk {
if info.fields.pk != nil {
err = errors.New(fmt.Sprintf("one model must have one pk field only"))
err = fmt.Errorf("one model must have one pk field only")
break
} else {
info.fields.pk = fi

View File

@ -25,6 +25,9 @@ import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
// As tidb can't use go get, so disable the tidb testing now
// _ "github.com/pingcap/tidb"
)
// A slice string field.
@ -76,21 +79,21 @@ func (e *SliceStringField) RawValue() interface{} {
var _ Fielder = new(SliceStringField)
// A json field.
type JsonField struct {
type JSONField struct {
Name string
Data string
}
func (e *JsonField) String() string {
func (e *JSONField) String() string {
data, _ := json.Marshal(e)
return string(data)
}
func (e *JsonField) FieldType() int {
func (e *JSONField) FieldType() int {
return TypeTextField
}
func (e *JsonField) SetRaw(value interface{}) error {
func (e *JSONField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
return json.Unmarshal([]byte(d), e)
@ -99,14 +102,14 @@ func (e *JsonField) SetRaw(value interface{}) error {
}
}
func (e *JsonField) RawValue() interface{} {
func (e *JSONField) RawValue() interface{} {
return e.String()
}
var _ Fielder = new(JsonField)
var _ Fielder = new(JSONField)
type Data struct {
Id int
ID int `orm:"column(id)"`
Boolean bool
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
@ -130,7 +133,7 @@ type Data struct {
}
type DataNull struct {
Id int
ID int `orm:"column(id)"`
Boolean bool `orm:"null"`
Char string `orm:"null;size(50)"`
Text string `orm:"null;type(text)"`
@ -193,7 +196,7 @@ type Float32 float64
type Float64 float64
type DataCustom struct {
Id int
ID int `orm:"column(id)"`
Boolean Boolean
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
@ -216,18 +219,18 @@ type DataCustom struct {
// only for mysql
type UserBig struct {
Id uint64
ID uint64 `orm:"column(id)"`
Name string
}
type User struct {
Id int
ID int `orm:"column(id)"`
UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"`
Password string `orm:"size(100)"`
Status int16 `orm:"column(Status)"`
IsStaff bool
IsActive bool `orm:"default(1)"`
IsActive bool `orm:"default(true)"`
Created time.Time `orm:"auto_now_add;type(date)"`
Updated time.Time `orm:"auto_now"`
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
@ -235,21 +238,21 @@ type User struct {
ShouldSkip string `orm:"-"`
Nums int
Langs SliceStringField `orm:"size(100)"`
Extra JsonField `orm:"type(text)"`
Extra JSONField `orm:"type(text)"`
unexport bool `orm:"-"`
unexport_ bool
unexportBool bool
}
func (u *User) TableIndex() [][]string {
return [][]string{
[]string{"Id", "UserName"},
[]string{"Id", "Created"},
{"Id", "UserName"},
{"Id", "Created"},
}
}
func (u *User) TableUnique() [][]string {
return [][]string{
[]string{"UserName", "Email"},
{"UserName", "Email"},
}
}
@ -259,7 +262,7 @@ func NewUser() *User {
}
type Profile struct {
Id int
ID int `orm:"column(id)"`
Age int16
Money float64
User *User `orm:"reverse(one)" json:"-"`
@ -276,7 +279,7 @@ func NewProfile() *Profile {
}
type Post struct {
Id int
ID int `orm:"column(id)"`
User *User `orm:"rel(fk)"`
Title string `orm:"size(60)"`
Content string `orm:"type(text)"`
@ -287,7 +290,7 @@ type Post struct {
func (u *Post) TableIndex() [][]string {
return [][]string{
[]string{"Id", "Created"},
{"Id", "Created"},
}
}
@ -297,7 +300,7 @@ func NewPost() *Post {
}
type Tag struct {
Id int
ID int `orm:"column(id)"`
Name string `orm:"size(30)"`
BestPost *Post `orm:"rel(one);null"`
Posts []*Post `orm:"reverse(many)" json:"-"`
@ -309,7 +312,7 @@ func NewTag() *Tag {
}
type PostTags struct {
Id int
ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk)"`
Tag *Tag `orm:"rel(fk)"`
}
@ -319,7 +322,7 @@ func (m *PostTags) TableName() string {
}
type Comment struct {
Id int
ID int `orm:"column(id)"`
Post *Post `orm:"rel(fk);column(post)"`
Content string `orm:"type(text)"`
Parent *Comment `orm:"null;rel(fk)"`
@ -331,6 +334,24 @@ func NewComment() *Comment {
return obj
}
type Group struct {
ID int `orm:"column(gid);size(32)"`
Name string
Permissions []*Permission `orm:"reverse(many)" json:"-"`
}
type Permission struct {
ID int `orm:"column(id)"`
Name string
Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"`
}
type GroupPermissions struct {
ID int `orm:"column(id)"`
Group *Group `orm:"rel(fk)"`
Permission *Permission `orm:"rel(fk)"`
}
var DBARGS = struct {
Driver string
Source string
@ -345,6 +366,7 @@ var (
IsMysql = DBARGS.Driver == "mysql"
IsSqlite = DBARGS.Driver == "sqlite3"
IsPostgres = DBARGS.Driver == "postgres"
IsTidb = DBARGS.Driver == "tidb"
)
var (
@ -364,6 +386,7 @@ Default DB Drivers.
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq
tidb: https://github.com/pingcap/tidb
usage:
@ -371,6 +394,7 @@ go get -u github.com/astaxie/beego/orm
go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq
go get -u github.com/pingcap/tidb
#### MySQL
mysql -u root -e 'create database orm_test;'
@ -390,6 +414,12 @@ psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/astaxie/beego/orm
#### TiDB
export ORM_DRIVER=tidb
export ORM_SOURCE='memory://test/test'
go test -v github.com/astaxie/beego/orm
`)
os.Exit(2)
}
@ -397,7 +427,7 @@ go test -v github.com/astaxie/beego/orm
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
alias := getDbAlias("default")
if alias.Driver == DR_MySQL {
if alias.Driver == DRMySQL {
alias.Engine = "INNODB"
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage
//
// package main
@ -59,12 +60,13 @@ import (
"time"
)
// DebugQueries define the debug
const (
Debug_Queries = iota
DebugQueries = iota
)
// Define common vars
var (
// DebugLevel = Debug_Queries
Debug = false
DebugLog = NewLog(os.Stderr)
DefaultRowsLimit = 1000
@ -79,7 +81,10 @@ var (
ErrNotImplement = errors.New("have not implement")
)
// Params stores the Params
type Params map[string]interface{}
// ParamsList stores paramslist
type ParamsList []interface{}
type orm struct {
@ -188,7 +193,7 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
o.setPk(mi, ind, id)
cnt += 1
cnt++
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
@ -489,7 +494,7 @@ func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
// create new orm
// NewOrm create new orm
func NewOrm() Ormer {
BootStrap() // execute only once
@ -501,7 +506,7 @@ func NewOrm() Ormer {
return o
}
// create a new ormer object with specify *sql.DB for query
// NewOrmWithDB create a new ormer object with specify *sql.DB for query
func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
var al *alias

View File

@ -19,6 +19,7 @@ import (
"strings"
)
// ExprSep define the expression separation
const (
ExprSep = "__"
)
@ -32,19 +33,19 @@ type condValue struct {
isCond bool
}
// condition struct.
// Condition struct.
// work for WHERE conditions.
type Condition struct {
params []condValue
}
// return new condition struct
// NewCondition return new condition struct
func NewCondition() *Condition {
c := &Condition{}
return c
}
// add expression to condition
// And add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.And> args cannot empty"))
@ -53,7 +54,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c
}
// add NOT expression to condition
// AndNot add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
@ -62,7 +63,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c
}
// combine a condition to current condition
// AndCond combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
@ -74,7 +75,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c
}
// add OR expression to condition
// Or add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.Or> args cannot empty"))
@ -83,7 +84,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c
}
// add OR NOT expression to condition
// OrNot add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
@ -92,7 +93,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c
}
// combine a OR condition to current condition
// OrCond combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
@ -104,12 +105,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c
}
// check the condition arguments are empty or not.
// IsEmpty check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool {
return len(c.params) == 0
}
// clone a condition
// clone clone a condition
func (c Condition) clone() *Condition {
return &c
}

View File

@ -23,11 +23,12 @@ import (
"time"
)
// Log implement the log.Logger
type Log struct {
*log.Logger
}
// set io.Writer to create a Logger.
// NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log {
d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9)
@ -41,7 +42,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if err != nil {
flag = "FAIL"
}
con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query)
con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(formatDateTime), alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args))
for _, arg := range args {
cons = append(cons, fmt.Sprintf("%v", arg))

View File

@ -14,9 +14,7 @@
package orm
import (
"reflect"
)
import "reflect"
// model to model struct
type queryM2M struct {
@ -44,7 +42,21 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
dbase := orm.alias.DbBaser
var models []interface{}
var otherValues []interface{}
var otherNames []string
for _, colname := range mi.fields.dbcols {
if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column &&
mi.fields.columns[colname] != mi.fields.pk {
otherNames = append(otherNames, colname)
}
}
for i, md := range mds {
if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 {
otherValues = append(otherValues, md)
mds = append(mds[:i], mds[i+1:]...)
}
}
for _, md := range mds {
val := reflect.ValueOf(md)
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
@ -67,11 +79,9 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
names := []string{mfi.column, rfi.column}
values := make([]interface{}, 0, len(models)*2)
for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md))
var v2 interface{}
if ind.Kind() != reflect.Struct {
v2 = ind.Interface()
@ -81,11 +91,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
panic(ErrMissPK)
}
}
values = append(values, v1, v2)
}
names = append(names, otherNames...)
values = append(values, otherValues...)
return dbase.InsertValue(orm.db, mi, true, names, values)
}

View File

@ -25,11 +25,12 @@ type colValue struct {
type operator int
// define Col operations
const (
Col_Add operator = iota
Col_Minus
Col_Multiply
Col_Except
ColAdd operator = iota
ColMinus
ColMultiply
ColExcept
)
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
@ -38,7 +39,7 @@ const (
// }
func ColValue(opt operator, value interface{}) interface{} {
switch opt {
case Col_Add, Col_Minus, Col_Multiply, Col_Except:
case ColAdd, ColMinus, ColMultiply, ColExcept:
default:
panic(fmt.Errorf("orm.ColValue wrong operator"))
}
@ -60,7 +61,9 @@ type querySet struct {
relDepth int
limit int64
offset int64
groups []string
orders []string
distinct bool
orm *orm
}
@ -105,6 +108,12 @@ func (o querySet) Offset(offset interface{}) QuerySeter {
return &o
}
// add GROUP expression
func (o querySet) GroupBy(exprs ...string) QuerySeter {
o.groups = exprs
return &o
}
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter {
@ -112,17 +121,22 @@ func (o querySet) OrderBy(exprs ...string) QuerySeter {
return &o
}
// add DISTINCT to SELECT
func (o querySet) Distinct() QuerySeter {
o.distinct = true
return &o
}
// set relation model to query together.
// it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
var related []string
if len(params) == 0 {
o.relDepth = DefaultRelsDepth
} else {
for _, p := range params {
switch val := p.(type) {
case string:
related = append(o.related, val)
o.related = append(o.related, val)
case int:
o.relDepth = val
default:
@ -130,7 +144,6 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
}
}
}
o.related = related
return &o
}

View File

@ -165,14 +165,14 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
if str != "" {
if len(str) >= 19 {
str = str[:19]
t, err := time.ParseInLocation(format_DateTime, str, o.orm.alias.TZ)
t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ)
if err == nil {
t = t.In(DefaultTimeLoc)
ind.Set(reflect.ValueOf(t))
}
} else if len(str) >= 10 {
str = str[:10]
t, err := time.ParseInLocation(format_Date, str, DefaultTimeLoc)
t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc)
if err == nil {
ind.Set(reflect.ValueOf(t))
}
@ -255,12 +255,13 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr
// query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error {
refs := make([]interface{}, 0, len(containers))
sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0)
var (
refs = make([]interface{}, 0, len(containers))
sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
)
structMode := false
var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@ -385,12 +386,13 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
// query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs := make([]interface{}, 0, len(containers))
sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0)
var (
refs = make([]interface{}, 0, len(containers))
sInds []reflect.Value
eTyps []reflect.Type
sMi *modelInfo
)
structMode := false
var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
sInd := reflect.Indirect(val)
@ -557,10 +559,9 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil {
rs, err := o.orm.db.Query(query, args...)
if err != nil {
return 0, err
} else {
rs = r
}
defer rs.Close()
@ -574,9 +575,10 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
for rs.Next() {
if cnt == 0 {
if columns, err := rs.Columns(); err != nil {
columns, err := rs.Columns()
if err != nil {
return 0, err
} else {
}
if len(needCols) > 0 {
indexs = make([]int, 0, len(needCols))
} else {
@ -585,7 +587,7 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
cols = columns
refs = make([]interface{}, len(cols))
for i, _ := range refs {
for i := range refs {
var ref sql.NullString
refs[i] = &ref
@ -600,7 +602,6 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er
}
}
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err
@ -684,11 +685,9 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
var rs *sql.Rows
if r, err := o.orm.db.Query(query, args...); err != nil {
rs, err := o.orm.db.Query(query, args...)
if err != nil {
return 0, err
} else {
rs = r
}
defer rs.Close()
@ -706,16 +705,16 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
for rs.Next() {
if cnt == 0 {
if columns, err := rs.Columns(); err != nil {
columns, err := rs.Columns()
if err != nil {
return 0, err
} else {
}
cols = columns
refs = make([]interface{}, len(cols))
for i, _ := range refs {
for i := range refs {
if keyCol == cols[i] {
keyIndex = i
}
if typ == 1 || keyIndex == i {
var ref sql.NullString
refs[i] = &ref
@ -723,17 +722,14 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
var ref interface{}
refs[i] = &ref
}
if valueCol == cols[i] {
valueIndex = i
}
}
if keyIndex == -1 || valueIndex == -1 {
panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol))
}
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err

View File

@ -31,13 +31,26 @@ import (
var _ = os.PathSeparator
var (
test_Date = format_Date + " -0700"
test_DateTime = format_DateTime + " -0700"
testDate = formatDate + " -0700"
testDateTime = formatDateTime + " -0700"
)
func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok bool) {
type argAny []interface{}
// get interface by index from interface slice
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err error) {
if len(args) == 0 {
return fmt.Errorf("miss args"), false
return false, fmt.Errorf("miss args")
}
b := args[0]
arg := argAny(args)
@ -71,21 +84,21 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok b
wrongArg:
if err != nil {
return err, false
return false, err
}
return nil, true
return true, nil
}
func AssertIs(a interface{}, args ...interface{}) error {
if err, ok := ValuesCompare(true, a, args...); ok == false {
if ok, err := ValuesCompare(true, a, args...); ok == false {
return err
}
return nil
}
func AssertNot(a interface{}, args ...interface{}) error {
if err, ok := ValuesCompare(false, a, args...); ok == false {
if ok, err := ValuesCompare(false, a, args...); ok == false {
return err
}
return nil
@ -171,8 +184,11 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Comment))
RegisterModel(new(UserBig))
RegisterModel(new(PostTags))
RegisterModel(new(Group))
RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions))
err := RunSyncdb("default", true, false)
err := RunSyncdb("default", true, Debug)
throwFail(t, err)
modelCache.clean()
@ -187,6 +203,9 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Comment))
RegisterModel(new(UserBig))
RegisterModel(new(PostTags))
RegisterModel(new(Group))
RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions))
BootStrap()
@ -208,7 +227,7 @@ func TestModelSyntax(t *testing.T) {
}
}
var Data_Values = map[string]interface{}{
var DataValues = map[string]interface{}{
"Boolean": true,
"Char": "char",
"Text": "text",
@ -235,7 +254,7 @@ func TestDataTypes(t *testing.T) {
d := Data{}
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
e.Set(reflect.ValueOf(value))
}
@ -244,22 +263,22 @@ func TestDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
d = Data{Id: 1}
d = Data{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@ -278,7 +297,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
d = DataNull{Id: 1}
d = DataNull{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
@ -309,7 +328,7 @@ func TestNullDataTypes(t *testing.T) {
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err)
d = DataNull{Id: 2}
d = DataNull{ID: 2}
err = dORM.Read(&d)
throwFail(t, err)
@ -362,7 +381,7 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 3))
d = DataNull{Id: 3}
d = DataNull{ID: 3}
err = dORM.Read(&d)
throwFail(t, err)
@ -402,7 +421,7 @@ func TestDataCustomTypes(t *testing.T) {
d := DataCustom{}
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
@ -414,13 +433,13 @@ func TestDataCustomTypes(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
d = DataCustom{Id: 1}
d = DataCustom{ID: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
@ -451,7 +470,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
u := &User{Id: user.Id}
u := &User{ID: user.ID}
err = dORM.Read(u)
throwFail(t, err)
@ -461,8 +480,8 @@ func TestCRUD(t *testing.T) {
throwFail(t, AssertIs(u.Status, 3))
throwFail(t, AssertIs(u.IsStaff, true))
throwFail(t, AssertIs(u.IsActive, true))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), test_Date))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), test_DateTime))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime))
user.UserName = "astaxie"
user.Profile = profile
@ -470,11 +489,11 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id}
u = &User{ID: user.ID}
err = dORM.Read(u)
throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "astaxie"))
throwFail(t, AssertIs(u.Profile.Id, profile.Id))
throwFail(t, AssertIs(u.Profile.ID, profile.ID))
u = &User{UserName: "astaxie", Password: "pass"}
err = dORM.Read(u, "UserName")
@ -487,7 +506,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id}
u = &User{ID: user.ID}
err = dORM.Read(u)
throwFailNow(t, err)
throwFail(t, AssertIs(u.UserName, "QQ"))
@ -497,7 +516,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
u = &User{Id: user.Id}
u = &User{ID: user.ID}
err = dORM.Read(u)
throwFail(t, err)
throwFail(t, AssertIs(true, u.Profile == nil))
@ -506,7 +525,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
u = &User{Id: 100}
u = &User{ID: 100}
err = dORM.Read(u)
throwFail(t, AssertIs(err, ErrNoRows))
@ -516,7 +535,7 @@ func TestCRUD(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
ub = UserBig{Id: 1}
ub = UserBig{ID: 1}
err = dORM.Read(&ub)
throwFail(t, err)
throwFail(t, AssertIs(ub.Name, "name"))
@ -586,29 +605,29 @@ func TestInsertTestData(t *testing.T) {
throwFail(t, AssertIs(id, 4))
tags := []*Tag{
&Tag{Name: "golang", BestPost: &Post{Id: 2}},
&Tag{Name: "example"},
&Tag{Name: "format"},
&Tag{Name: "c++"},
{Name: "golang", BestPost: &Post{ID: 2}},
{Name: "example"},
{Name: "format"},
{Name: "c++"},
}
posts := []*Post{
&Post{User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory resultJava programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand.
{User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory resultJava programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand.
This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`},
&Post{User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`},
&Post{User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide.
{User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`},
{User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide.
With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`},
&Post{User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code.
{User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code.
The programand web servergodoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`},
}
comments := []*Comment{
&Comment{Post: posts[0], Content: "a comment"},
&Comment{Post: posts[1], Content: "yes"},
&Comment{Post: posts[1]},
&Comment{Post: posts[1]},
&Comment{Post: posts[2]},
&Comment{Post: posts[2]},
{Post: posts[0], Content: "a comment"},
{Post: posts[1], Content: "yes"},
{Post: posts[1]},
{Post: posts[1]},
{Post: posts[2]},
{Post: posts[2]},
}
for _, tag := range tags {
@ -635,10 +654,47 @@ The program—and web server—godoc processes Go source files to extract docume
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
}
permissions := []*Permission{
{Name: "writePosts"},
{Name: "readComments"},
{Name: "readPosts"},
}
groups := []*Group{
{
Name: "admins",
Permissions: []*Permission{permissions[0], permissions[1], permissions[2]},
},
{
Name: "users",
Permissions: []*Permission{permissions[1], permissions[2]},
},
}
for _, permission := range permissions {
id, err := dORM.Insert(permission)
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
}
for _, group := range groups {
_, err := dORM.Insert(group)
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
num := len(group.Permissions)
if num > 0 {
nums, err := dORM.QueryM2M(group, "permissions").Add(group.Permissions)
throwFailNow(t, err)
throwFailNow(t, AssertIs(nums, num))
}
}
}
func TestCustomField(t *testing.T) {
user := User{Id: 2}
user := User{ID: 2}
err := dORM.Read(&user)
throwFailNow(t, err)
@ -648,7 +704,7 @@ func TestCustomField(t *testing.T) {
_, err = dORM.Update(&user, "Langs", "Extra")
throwFailNow(t, err)
user = User{Id: 2}
user = User{ID: 2}
err = dORM.Read(&user)
throwFailNow(t, err)
throwFailNow(t, AssertIs(len(user.Langs), 2))
@ -702,7 +758,7 @@ func TestOperators(t *testing.T) {
var shouldNum int
if IsSqlite {
if IsSqlite || IsTidb {
shouldNum = 2
} else {
shouldNum = 0
@ -740,7 +796,7 @@ func TestOperators(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
if IsSqlite {
if IsSqlite || IsTidb {
shouldNum = 1
} else {
shouldNum = 0
@ -758,7 +814,7 @@ func TestOperators(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 2))
if IsSqlite {
if IsSqlite || IsTidb {
shouldNum = 2
} else {
shouldNum = 0
@ -889,9 +945,9 @@ func TestAll(t *testing.T) {
throwFailNow(t, AssertIs(users2[0].UserName, "slene"))
throwFailNow(t, AssertIs(users2[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(users2[2].UserName, "nobody"))
throwFailNow(t, AssertIs(users2[0].Id, 0))
throwFailNow(t, AssertIs(users2[1].Id, 0))
throwFailNow(t, AssertIs(users2[2].Id, 0))
throwFailNow(t, AssertIs(users2[0].ID, 0))
throwFailNow(t, AssertIs(users2[1].ID, 0))
throwFailNow(t, AssertIs(users2[2].ID, 0))
throwFailNow(t, AssertIs(users2[0].Profile == nil, false))
throwFailNow(t, AssertIs(users2[1].Profile == nil, false))
throwFailNow(t, AssertIs(users2[2].Profile == nil, true))
@ -986,6 +1042,10 @@ func TestValuesFlat(t *testing.T) {
}
func TestRelatedSel(t *testing.T) {
if IsTidb {
// Skip it. TiDB does not support relation now.
return
}
qs := dORM.QueryTable("user")
num, err := qs.Filter("profile__age", 28).Count()
throwFail(t, err)
@ -1112,7 +1172,7 @@ func TestReverseQuery(t *testing.T) {
func TestLoadRelated(t *testing.T) {
// load reverse foreign key
user := User{Id: 3}
user := User{ID: 3}
err := dORM.Read(&user)
throwFailNow(t, err)
@ -1121,7 +1181,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].User.Id, 3))
throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3))
num, err = dORM.LoadRelated(&user, "Posts", true)
throwFailNow(t, err)
@ -1143,8 +1203,8 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
// load reverse one to one
profile := Profile{Id: 3}
profile.BestPost = &Post{Id: 2}
profile := Profile{ID: 3}
profile.BestPost = &Post{ID: 2}
num, err = dORM.Update(&profile, "BestPost")
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
@ -1183,7 +1243,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false))
throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples"))
post := Post{Id: 2}
post := Post{ID: 2}
// load rel foreign key
err = dORM.Read(&post)
@ -1204,7 +1264,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.User.Profile.Age, 30))
// load rel m2m
post = Post{Id: 2}
post = Post{ID: 2}
err = dORM.Read(&post)
throwFailNow(t, err)
@ -1224,7 +1284,7 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie"))
// load reverse m2m
tag := Tag{Id: 1}
tag := Tag{ID: 1}
err = dORM.Read(&tag)
throwFailNow(t, err)
@ -1233,22 +1293,22 @@ func TestLoadRelated(t *testing.T) {
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true))
num, err = dORM.LoadRelated(&tag, "Posts", true)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2))
throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene"))
}
func TestQueryM2M(t *testing.T) {
post := Post{Id: 4}
post := Post{ID: 4}
m2m := dORM.QueryM2M(&post, "Tags")
tag1 := []*Tag{&Tag{Name: "TestTag1"}, &Tag{Name: "TestTag2"}}
tag1 := []*Tag{{Name: "TestTag1"}, {Name: "TestTag2"}}
tag2 := &Tag{Name: "TestTag3"}
tag3 := []interface{}{&Tag{Name: "TestTag4"}}
@ -1311,7 +1371,7 @@ func TestQueryM2M(t *testing.T) {
m2m = dORM.QueryM2M(&tag, "Posts")
post1 := []*Post{&Post{Title: "TestPost1"}, &Post{Title: "TestPost2"}}
post1 := []*Post{{Title: "TestPost1"}, {Title: "TestPost2"}}
post2 := &Post{Title: "TestPost3"}
post3 := []interface{}{&Post{Title: "TestPost4"}}
@ -1319,7 +1379,7 @@ func TestQueryM2M(t *testing.T) {
for _, post := range posts {
p := post.(*Post)
p.User = &User{Id: 1}
p.User = &User{ID: 1}
_, err := dORM.Insert(post)
throwFailNow(t, err)
}
@ -1394,6 +1454,18 @@ func TestQueryRelate(t *testing.T) {
// throwFailNow(t, AssertIs(num, 2))
}
func TestPkManyRelated(t *testing.T) {
permission := &Permission{Name: "readPosts"}
err := dORM.Read(permission, "Name")
throwFailNow(t, err)
var groups []*Group
qs := dORM.QueryTable("Group")
num, err := qs.Filter("Permissions__Permission", permission.ID).All(&groups)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2))
}
func TestPrepareInsert(t *testing.T) {
qs := dORM.QueryTable("user")
i, err := qs.PrepareInsert()
@ -1459,10 +1531,10 @@ func TestRawQueryRow(t *testing.T) {
Decimal float64
)
data_values := make(map[string]interface{}, len(Data_Values))
dataValues := make(map[string]interface{}, len(DataValues))
for k, v := range Data_Values {
data_values[strings.ToLower(k)] = v
for k, v := range DataValues {
dataValues[strings.ToLower(k)] = v
}
Q := dDbBaser.TableQuote()
@ -1488,14 +1560,14 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(id, 1))
case "date":
v = v.(time.Time).In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_Date))
value := dataValues[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, testDate))
case "datetime":
v = v.(time.Time).In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_DateTime))
value := dataValues[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, testDateTime))
default:
throwFail(t, AssertIs(v, data_values[col]))
throwFail(t, AssertIs(v, dataValues[col]))
}
}
@ -1529,16 +1601,16 @@ func TestQueryRows(t *testing.T) {
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@ -1553,16 +1625,16 @@ func TestQueryRows(t *testing.T) {
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
for name, value := range Data_Values {
for name, value := range DataValues {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDate)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
@ -1699,25 +1771,25 @@ func TestUpdate(t *testing.T) {
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Add, 100),
"Nums": ColValue(ColAdd, 100),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Minus, 50),
"Nums": ColValue(ColMinus, 50),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Multiply, 3),
"Nums": ColValue(ColMultiply, 3),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name", "slene").Update(Params{
"Nums": ColValue(Col_Except, 5),
"Nums": ColValue(ColExcept, 5),
})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
@ -1838,15 +1910,15 @@ func TestReadOrCreate(t *testing.T) {
throwFail(t, AssertIs(u.Status, 7))
throwFail(t, AssertIs(u.IsStaff, false))
throwFail(t, AssertIs(u.IsActive, true))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), test_Date))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), test_DateTime))
throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), u.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), u.Updated.In(DefaultTimeLoc), testDateTime))
nu := &User{UserName: u.UserName, Email: "someotheremail@gmail.com"}
created, pk, err = dORM.ReadOrCreate(nu, "UserName")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(nu.Id, u.Id))
throwFail(t, AssertIs(pk, u.Id))
throwFail(t, AssertIs(nu.ID, u.ID))
throwFail(t, AssertIs(pk, u.ID))
throwFail(t, AssertIs(nu.UserName, u.UserName))
throwFail(t, AssertIs(nu.Email, u.Email)) // should contain the value in the table, not the one specified above
throwFail(t, AssertIs(nu.Password, u.Password))

View File

@ -16,6 +16,7 @@ package orm
import "errors"
// QueryBuilder is the Query builder interface
type QueryBuilder interface {
Select(fields ...string) QueryBuilder
From(tables ...string) QueryBuilder
@ -43,15 +44,18 @@ type QueryBuilder interface {
String() string
}
// NewQueryBuilder return the QueryBuilder
func NewQueryBuilder(driver string) (qb QueryBuilder, err error) {
if driver == "mysql" {
qb = new(MySQLQueryBuilder)
} else if driver == "tidb" {
qb = new(TiDBQueryBuilder)
} else if driver == "postgres" {
err = errors.New("postgres query builder is not supported yet!")
err = errors.New("postgres query builder is not supported yet")
} else if driver == "sqlite" {
err = errors.New("sqlite query builder is not supported yet!")
err = errors.New("sqlite query builder is not supported yet")
} else {
err = errors.New("unknown driver for query builder!")
err = errors.New("unknown driver for query builder")
}
return
}

View File

@ -20,134 +20,160 @@ import (
"strings"
)
const COMMA_SPACE = ", "
// CommaSpace is the separation
const CommaSpace = ", "
// MySQLQueryBuilder is the SQL build
type MySQLQueryBuilder struct {
Tokens []string
}
// Select will join the fields
func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb
}
// From join the tables
func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb
}
// InnerJoin INNER JOIN the table
func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb
}
// LeftJoin LEFT JOIN the table
func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb
}
// RightJoin RIGHT JOIN the table
func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb
}
// On join with on cond
func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond)
return qb
}
// Where join the Where cond
func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb
}
// And join the and cond
func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond)
return qb
}
// Or join the or cond
func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond)
return qb
}
// In join the IN (vals)
func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, COMMA_SPACE), ")")
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb
}
// OrderBy join the Order by fields
func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb
}
// Asc join the asc
func (qb *MySQLQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC")
return qb
}
// Desc join the desc
func (qb *MySQLQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC")
return qb
}
// Limit join the limit num
func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb
}
// Offset join the offset num
func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb
}
// GroupBy join the Group by fields
func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb
}
// Having join the Having cond
func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb
}
// Update join the update table
func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb
}
// Set join the set kv
func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb
}
// Delete join the Delete tables
func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 {
qb.Tokens = append(qb.Tokens, strings.Join(tables, COMMA_SPACE))
qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
}
return qb
}
// InsertInto join the insert SQL
func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 {
fieldsStr := strings.Join(fields, COMMA_SPACE)
fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
}
return qb
}
// Values join the Values(vals)
func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder {
valsStr := strings.Join(vals, COMMA_SPACE)
valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb
}
// Subquery join the sub as alias
func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias)
}
// String join all Tokens
func (qb *MySQLQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ")
}

176
orm/qb_tidb.go Normal file
View File

@ -0,0 +1,176 @@
// Copyright 2015 TiDB Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strconv"
"strings"
)
// TiDBQueryBuilder is the SQL build
type TiDBQueryBuilder struct {
Tokens []string
}
// Select will join the fields
func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace))
return qb
}
// From join the tables
func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
return qb
}
// InnerJoin INNER JOIN the table
func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INNER JOIN", table)
return qb
}
// LeftJoin LEFT JOIN the table
func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LEFT JOIN", table)
return qb
}
// RightJoin RIGHT JOIN the table
func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table)
return qb
}
// On join with on cond
func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ON", cond)
return qb
}
// Where join the Where cond
func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "WHERE", cond)
return qb
}
// And join the and cond
func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "AND", cond)
return qb
}
// Or join the or cond
func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OR", cond)
return qb
}
// In join the IN (vals)
func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")")
return qb
}
// OrderBy join the Order by fields
func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace))
return qb
}
// Asc join the asc
func (qb *TiDBQueryBuilder) Asc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "ASC")
return qb
}
// Desc join the desc
func (qb *TiDBQueryBuilder) Desc() QueryBuilder {
qb.Tokens = append(qb.Tokens, "DESC")
return qb
}
// Limit join the limit num
func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit))
return qb
}
// Offset join the offset num
func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder {
qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset))
return qb
}
// GroupBy join the Group by fields
func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace))
return qb
}
// Having join the Having cond
func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "HAVING", cond)
return qb
}
// Update join the update table
func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace))
return qb
}
// Set join the set kv
func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace))
return qb
}
// Delete join the Delete tables
func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "DELETE")
if len(tables) != 0 {
qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace))
}
return qb
}
// InsertInto join the insert SQL
func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "INSERT INTO", table)
if len(fields) != 0 {
fieldsStr := strings.Join(fields, CommaSpace)
qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")")
}
return qb
}
// Values join the Values(vals)
func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder {
valsStr := strings.Join(vals, CommaSpace)
qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")")
return qb
}
// Subquery join the sub as alias
func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string {
return fmt.Sprintf("(%s) AS %s", sub, alias)
}
// String join all Tokens
func (qb *TiDBQueryBuilder) String() string {
return strings.Join(qb.Tokens, " ")
}

View File

@ -20,13 +20,13 @@ import (
"time"
)
// database driver
// Driver define database driver
type Driver interface {
Name() string
Type() DriverType
}
// field info
// Fielder define field info
type Fielder interface {
String() string
FieldType() int
@ -34,84 +34,315 @@ type Fielder interface {
RawValue() interface{}
}
// orm struct
// Ormer define the orm interface
type Ormer interface {
Read(interface{}, ...string) error
ReadOrCreate(interface{}, string, ...string) (bool, int64, error)
// read data to model
// for example:
// this will find User by Id field
// u = &User{Id: user.Id}
// err = Ormer.Read(u)
// this will find User by UserName field
// u = &User{UserName: "astaxie", Password: "pass"}
// err = Ormer.Read(u, "UserName")
Read(md interface{}, cols ...string) error
// Try to read a row from the database, or insert one if it doesn't exist
ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error)
// insert model data to database
// for example:
// user := new(User)
// id, err = Ormer.Insert(user)
// user must a pointer and Insert will set user's pk field
Insert(interface{}) (int64, error)
InsertMulti(int, interface{}) (int64, error)
Update(interface{}, ...string) (int64, error)
Delete(interface{}) (int64, error)
LoadRelated(interface{}, string, ...interface{}) (int64, error)
QueryM2M(interface{}, string) QueryM2Mer
QueryTable(interface{}) QuerySeter
Using(string) error
// insert some models to database
InsertMulti(bulk int, mds interface{}) (int64, error)
// update model to database.
// cols set the columns those want to update.
// find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns
// for example:
// user := User{Id: 2}
// user.Langs = append(user.Langs, "zh-CN", "en-US")
// user.Extra.Name = "beego"
// user.Extra.Data = "orm"
// num, err = Ormer.Update(&user, "Langs", "Extra")
Update(md interface{}, cols ...string) (int64, error)
// delete model in database
Delete(md interface{}) (int64, error)
// load related models to md model.
// args are limit, offset int and order string.
//
// example:
// Ormer.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//args[0] bool true useDefaultRelsDepth ; false depth 0
//args[0] int loadRelationDepth
//args[1] int limit default limit 1000
//args[2] int offset default offset 0
//args[3] string order for example : "-Id"
// make sure the relation is defined in model struct tags.
LoadRelated(md interface{}, name string, args ...interface{}) (int64, error)
// create a models to models queryer
// for example:
// post := Post{Id: 4}
// m2m := Ormer.QueryM2M(&post, "Tags")
QueryM2M(md interface{}, name string) QueryM2Mer
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
QueryTable(ptrStructOrTableName interface{}) QuerySeter
// switch to another registered database driver by given name.
Using(name string) error
// begin transaction
// for example:
// o := NewOrm()
// err := o.Begin()
// ...
// err = o.Rollback()
Begin() error
// commit transaction
Commit() error
// rollback transaction
Rollback() error
Raw(string, ...interface{}) RawSeter
// return a raw query seter for raw sql string.
// for example:
// ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec()
// // update user testing's name to slene
Raw(query string, args ...interface{}) RawSeter
Driver() Driver
}
// insert prepared statement
// Inserter insert prepared statement
type Inserter interface {
Insert(interface{}) (int64, error)
Close() error
}
// query seter
// QuerySeter query seter
type QuerySeter interface {
// add condition expression to QuerySeter.
// for example:
// filter by UserName == 'slene'
// qs.Filter("UserName", "slene")
// sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28
// Filter("profile__Age", 28)
// // time compare
// qs.Filter("created", time.Now())
Filter(string, ...interface{}) QuerySeter
// add NOT condition to querySeter.
// have the same usage as Filter
Exclude(string, ...interface{}) QuerySeter
// set condition to QuerySeter.
// sql's where condition
// cond := orm.NewCondition()
// cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond1).Count()
SetCond(*Condition) QuerySeter
Limit(interface{}, ...interface{}) QuerySeter
Offset(interface{}) QuerySeter
OrderBy(...string) QuerySeter
RelatedSel(...interface{}) QuerySeter
// add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
// if Limit <= 0 then Limit will be set to default limit ,eg 1000
// if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000
// for example:
// qs.Limit(10, 2)
// // sql-> limit 10 offset 2
Limit(limit interface{}, args ...interface{}) QuerySeter
// add OFFSET value
// same as Limit function's args[0]
Offset(offset interface{}) QuerySeter
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
// for example:
// qs.OrderBy("-status")
OrderBy(exprs ...string) QuerySeter
// set relation model to query together.
// it will query relation models and assign to parent model.
// for example:
// // will load all related fields use left join .
// qs.RelatedSel().One(&user)
// // will load related field only profile
// qs.RelatedSel("profile").One(&user)
// user.Profile.Age = 32
RelatedSel(params ...interface{}) QuerySeter
// return QuerySeter execution result number
// for example:
// num, err = qs.Filter("profile__age__gt", 28).Count()
Count() (int64, error)
// check result empty or not after QuerySeter executed
// the same as QuerySeter.Count > 0
Exist() bool
Update(Params) (int64, error)
// execute update with parameters
// for example:
// num, err = qs.Filter("user_name", "slene").Update(Params{
// "Nums": ColValue(Col_Minus, 50),
// }) // user slene's Nums will minus 50
// num, err = qs.Filter("UserName", "slene").Update(Params{
// "user_name": "slene2"
// }) // user slene's name will change to slene2
Update(values Params) (int64, error)
// delete from table
//for example:
// num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
// //delete two user who's name is testing1 or testing2
Delete() (int64, error)
// return a insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
// num, err = i.Insert(&user1) // user table will add one record user1 at once
// num, err = i.Insert(&user2) // user table will add one record user2 at once
// err = i.Close() //don't forget call Close
PrepareInsert() (Inserter, error)
All(interface{}, ...string) (int64, error)
One(interface{}, ...string) error
Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList, string) (int64, error)
RowsToMap(*Params, string, string) (int64, error)
RowsToStruct(interface{}, string, string) (int64, error)
// query all data and map to containers.
// cols means the columns when querying.
// for example:
// var users []*User
// qs.All(&users) // users[0],users[1],users[2] ...
All(container interface{}, cols ...string) (int64, error)
// query one row data and map to containers.
// cols means the columns when querying.
// for example:
// var user User
// qs.One(&user) //user.UserName == "slene"
One(container interface{}, cols ...string) error
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
// for example:
// var maps []Params
// qs.Values(&maps) //maps[0]["UserName"]=="slene"
Values(results *[]Params, exprs ...string) (int64, error)
// query all data and map to [][]interface
// it converts data to [][column_index]value
// for example:
// var list []ParamsList
// qs.ValuesList(&list) // list[0][1] == "slene"
ValuesList(results *[]ParamsList, exprs ...string) (int64, error)
// query all data and map to []interface.
// it's designed for one column record set, auto change to []value, not [][column]value.
// for example:
// var list ParamsList
// qs.ValuesFlat(&list, "UserName") // list[0] == "slene"
ValuesFlat(result *ParamsList, expr string) (int64, error)
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
}
// model to model query struct
// QueryM2Mer model to model query struct
// all operations are on the m2m table only, will not affect the origin model table
type QueryM2Mer interface {
// add models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}{ ... }
// param could also be any of the follow
// []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}}
// &Tag{Id:5,Name: "TestTag3"}
// []interface{}{&Tag{Id:6,Name: "TestTag4"}}
// insert one or more rows to m2m table
// make sure the relation is defined in post model struct tag.
Add(...interface{}) (int64, error)
// remove models following the origin model relationship
// only delete rows from m2m table
// for example:
//tag3 := &Tag{Id:5,Name: "TestTag3"}
//num, err = m2m.Remove(tag3)
Remove(...interface{}) (int64, error)
// check model is existed in relationship of origin model
Exist(interface{}) bool
// clean all models in related of origin model
Clear() (int64, error)
// count all related models of origin model
Count() (int64, error)
}
// raw query statement
// RawPreparer raw query statement
type RawPreparer interface {
Exec(...interface{}) (sql.Result, error)
Close() error
}
// raw query seter
// RawSeter raw query seter
// create From Ormer.Raw
// for example:
// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q)
// rs := Ormer.Raw(sql, 1)
type RawSeter interface {
//execute sql and get result
Exec() (sql.Result, error)
QueryRow(...interface{}) error
QueryRows(...interface{}) (int64, error)
//query data and map to container
//for example:
// var name string
// var id int
// rs.QueryRow(&id,&name) // id==2 name=="slene"
QueryRow(containers ...interface{}) error
// query data rows and map to container
// var ids []int
// var names []int
// query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q)
// num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"}
QueryRows(containers ...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList, ...string) (int64, error)
RowsToMap(*Params, string, string) (int64, error)
RowsToStruct(interface{}, string, string) (int64, error)
// query data to []map[string]interface
// see QuerySeter's Values
Values(container *[]Params, cols ...string) (int64, error)
// query data to [][]interface
// see QuerySeter's ValuesList
ValuesList(container *[]ParamsList, cols ...string) (int64, error)
// query data to []interface
// see QuerySeter's ValuesFlat
ValuesFlat(container *ParamsList, cols ...string) (int64, error)
// query all rows into map[string]interface with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to map[string]interface{}{
// "total": 100,
// "found": 200,
// }
RowsToMap(result *Params, keyCol, valueCol string) (int64, error)
// query all rows into struct with specify key and value column name.
// keyCol = "name", valueCol = "value"
// table data
// name | value
// total | 100
// found | 200
// to struct {
// Total int
// Found int
// }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
// return prepared raw statement for used in times.
// for example:
// pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare()
// r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`)
Prepare() (RawPreparer, error)
}
// statement querier
// stmtQuerier statement querier
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
@ -160,8 +391,8 @@ type dbBaser interface {
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
OperatorSql(string) string
GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
OperatorSQL(string) string
GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{})
GenerateOperatorLeftCol(*fieldInfo, string, *string)
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error)

View File

@ -22,9 +22,10 @@ import (
"time"
)
// StrTo is the target string
type StrTo string
// set string
// Set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
@ -33,93 +34,93 @@ func (f *StrTo) Set(v string) {
}
}
// clean string
// Clear string
func (f *StrTo) Clear() {
*f = StrTo(0x1E)
}
// check string exist
// Exist check string exist
func (f StrTo) Exist() bool {
return string(f) != string(0x1E)
}
// string to bool
// Bool string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
// string to float32
// Float32 string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
// string to float64
// Float64 string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
// string to int
// Int string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
// string to int8
// Int8 string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
// string to int16
// Int16 string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
// string to int32
// Int32 string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
// string to int64
// Int64 string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err
}
// string to uint
// Uint string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
// string to uint8
// Uint8 string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
// string to uint16
// Uint16 string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
// string to uint31
// Uint32 string to uint31
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
// string to uint64
// Uint64 string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err
}
// string to string
// String string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
@ -127,7 +128,7 @@ func (f StrTo) String() string {
return ""
}
// interface to string
// ToStr interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
@ -166,7 +167,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s
}
// interface to int64
// ToInt64 interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {
@ -195,7 +196,7 @@ func snakeString(s string) string {
}
data = append(data, d)
}
return strings.ToLower(string(data[:len(data)]))
return strings.ToLower(string(data[:]))
}
// camel string, xx_yy to XxYy
@ -220,7 +221,7 @@ func camelString(s string) string {
}
data = append(data, d)
}
return string(data[:len(data)])
return string(data[:])
}
type argString []string
@ -248,30 +249,12 @@ func (a argInt) Get(i int, args ...int) (r int) {
return
}
type argAny []interface{}
// get interface by index from interface slice
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
// format time string
func timeFormat(t time.Time, format string) string {
return t.Format(format)
}
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() {

Some files were not shown because too many files have changed in this diff Show More