Merge pull request #3 from astaxie/develop

develop
This commit is contained in:
Waleed Gadelkareem 2019-02-14 16:23:35 +01:00 committed by GitHub
commit 1942438b22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
136 changed files with 16996 additions and 758 deletions

View File

@ -1,4 +0,0 @@
github.com/astaxie/beego/*/*:S1012
github.com/astaxie/beego/*:S1012
github.com/astaxie/beego/*/*:S1007
github.com/astaxie/beego/*:S1007

View File

@ -1,9 +1,8 @@
language: go language: go
go: go:
- 1.6.4 - "1.10.x"
- 1.7.5 - "1.11.x"
- 1.8.1
services: services:
- redis-server - redis-server
- mysql - mysql
@ -11,7 +10,6 @@ services:
- memcached - memcached
env: env:
- ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
- ORM_DRIVER=mysql ORM_SOURCE="root:@/orm_test?charset=utf8"
- ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
before_install: before_install:
- git clone git://github.com/ideawu/ssdb.git - git clone git://github.com/ideawu/ssdb.git
@ -23,10 +21,11 @@ install:
- go get github.com/go-sql-driver/mysql - go get github.com/go-sql-driver/mysql
- go get github.com/mattn/go-sqlite3 - go get github.com/mattn/go-sqlite3
- go get github.com/bradfitz/gomemcache/memcache - go get github.com/bradfitz/gomemcache/memcache
- go get github.com/garyburd/redigo/redis - go get github.com/gomodule/redigo/redis
- go get github.com/beego/x2j - go get github.com/beego/x2j
- go get github.com/couchbase/go-couchbase - go get github.com/couchbase/go-couchbase
- go get github.com/beego/goyaml2 - go get github.com/beego/goyaml2
- go get gopkg.in/yaml.v2
- go get github.com/belogik/goes - go get github.com/belogik/goes
- go get github.com/siddontang/ledisdb/config - go get github.com/siddontang/ledisdb/config
- go get github.com/siddontang/ledisdb/ledis - go get github.com/siddontang/ledisdb/ledis
@ -35,28 +34,31 @@ install:
- go get github.com/gogo/protobuf/proto - go get github.com/gogo/protobuf/proto
- go get github.com/Knetic/govaluate - go get github.com/Knetic/govaluate
- go get github.com/casbin/casbin - go get github.com/casbin/casbin
- go get -u honnef.co/go/tools/cmd/gosimple - go get github.com/elazarl/go-bindata-assetfs
- go get github.com/OwnLocal/goes
- go get -u honnef.co/go/tools/cmd/staticcheck
- go get -u github.com/mdempsky/unconvert - go get -u github.com/mdempsky/unconvert
- go get -u github.com/gordonklaus/ineffassign - go get -u github.com/gordonklaus/ineffassign
- go get -u github.com/golang/lint/golint - go get -u github.com/golang/lint/golint
- go get -u github.com/go-redis/redis
before_script: before_script:
- psql --version - psql --version
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
- sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi"
- sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi"
- sh -c "if [ $(go version) == *1.[5-9]* ]; then go get github.com/golang/lint/golint; golint ./...; fi" - sh -c "go get github.com/golang/lint/golint; golint ./...;"
- sh -c "if [ $(go version) == *1.[5-9]* ]; then go tool vet .; fi" - sh -c "go list ./... | grep -v vendor | xargs go vet -v"
- mkdir -p res/var - mkdir -p res/var
- ./ssdb/ssdb-server ./ssdb/ssdb.conf -d - ./ssdb/ssdb-server ./ssdb/ssdb.conf -d
after_script: after_script:
-killall -w ssdb-server - killall -w ssdb-server
- rm -rf ./res/var/* - rm -rf ./res/var/*
script: script:
- go test -v ./... - go test -v ./...
- gosimple -ignore "$(cat .gosimpleignore)" $(go list ./... | grep -v /vendor/) - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024"
- unconvert $(go list ./... | grep -v /vendor/) - unconvert $(go list ./... | grep -v /vendor/)
- ineffassign . - ineffassign .
- find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s
- golint ./... - golint ./...
addons: addons:
postgresql: "9.4" postgresql: "9.6"

View File

@ -1,8 +1,11 @@
# Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) # Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) [![Go Report Card](https://goreportcard.com/badge/github.com/astaxie/beego)](https://goreportcard.com/report/github.com/astaxie/beego)
beego is used for rapid development of RESTful APIs, web apps and backend services in Go. beego is used for rapid development of RESTful APIs, web apps and backend services in Go.
It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding. It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding.
Response time ranking: [web-frameworks](https://github.com/the-benchmarker/web-frameworks).
###### More info at [beego.me](http://beego.me). ###### More info at [beego.me](http://beego.me).
## Quick Start ## Quick Start

View File

@ -20,11 +20,10 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
"reflect"
"text/template" "text/template"
"time" "time"
"reflect"
"github.com/astaxie/beego/grace" "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/toolbox"
@ -35,7 +34,7 @@ import (
var beeAdminApp *adminApp var beeAdminApp *adminApp
// FilterMonitorFunc is default monitor filter when admin module is enable. // FilterMonitorFunc is default monitor filter when admin module is enable.
// if this func returns, admin module records qbs for this request by condition of this function logic. // if this func returns, admin module records qps for this request by condition of this function logic.
// usage: // usage:
// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { // func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool {
// if method == "POST" { // if method == "POST" {
@ -67,15 +66,27 @@ func init() {
// AdminIndex is the default http.Handler for admin module. // AdminIndex is the default http.Handler for admin module.
// it matches url pattern "/". // it matches url pattern "/".
func adminIndex(rw http.ResponseWriter, r *http.Request) { func adminIndex(rw http.ResponseWriter, _ *http.Request) {
execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl)
} }
// QpsIndex is the http.Handler for writing qbs statistics map result info in http.ResponseWriter. // QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter.
// it's registered with url pattern "/qbs" in admin module. // it's registered with url pattern "/qps" in admin module.
func qpsIndex(rw http.ResponseWriter, r *http.Request) { func qpsIndex(rw http.ResponseWriter, _ *http.Request) {
data := make(map[interface{}]interface{}) data := make(map[interface{}]interface{})
data["Content"] = toolbox.StatisticsMap.GetMap() data["Content"] = toolbox.StatisticsMap.GetMap()
// do html escape before display path, avoid xss
if content, ok := (data["Content"]).(M); ok {
if resultLists, ok := (content["Data"]).([][]string); ok {
for i := range resultLists {
if len(resultLists[i]) > 0 {
resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0])
}
}
}
}
execTpl(rw, data, qpsTpl, defaultScriptsTpl) execTpl(rw, data, qpsTpl, defaultScriptsTpl)
} }
@ -92,7 +103,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
data := make(map[interface{}]interface{}) data := make(map[interface{}]interface{})
switch command { switch command {
case "conf": case "conf":
m := make(map[string]interface{}) m := make(M)
list("BConfig", BConfig, m) list("BConfig", BConfig, m)
m["AppConfigPath"] = appConfigPath m["AppConfigPath"] = appConfigPath
m["AppConfigProvider"] = appConfigProvider m["AppConfigProvider"] = appConfigProvider
@ -116,14 +127,14 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
case "filter": case "filter":
var ( var (
content = map[string]interface{}{ content = M{
"Fields": []string{ "Fields": []string{
"Router Pattern", "Router Pattern",
"Filter Function", "Filter Function",
}, },
} }
filterTypes = []string{} filterTypes = []string{}
filterTypeData = make(map[string]interface{}) filterTypeData = make(M)
) )
if BeeApp.Handlers.enableFilter { if BeeApp.Handlers.enableFilter {
@ -161,7 +172,7 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
} }
} }
func list(root string, p interface{}, m map[string]interface{}) { func list(root string, p interface{}, m M) {
pt := reflect.TypeOf(p) pt := reflect.TypeOf(p)
pv := reflect.ValueOf(p) pv := reflect.ValueOf(p)
if pt.Kind() == reflect.Ptr { if pt.Kind() == reflect.Ptr {
@ -184,11 +195,11 @@ func list(root string, p interface{}, m map[string]interface{}) {
} }
// PrintTree prints all registered routers. // PrintTree prints all registered routers.
func PrintTree() map[string]interface{} { func PrintTree() M {
var ( var (
content = map[string]interface{}{} content = M{}
methods = []string{} methods = []string{}
methodsData = make(map[string]interface{}) methodsData = make(M)
) )
for method, t := range BeeApp.Handlers.routers { for method, t := range BeeApp.Handlers.routers {
@ -279,12 +290,12 @@ func profIndex(rw http.ResponseWriter, r *http.Request) {
// Healthcheck is a http.Handler calling health checking and showing the result. // Healthcheck is a http.Handler calling health checking and showing the result.
// it's in "/healthcheck" pattern in admin module. // it's in "/healthcheck" pattern in admin module.
func healthcheck(rw http.ResponseWriter, req *http.Request) { func healthcheck(rw http.ResponseWriter, _ *http.Request) {
var ( var (
result []string result []string
data = make(map[interface{}]interface{}) data = make(map[interface{}]interface{})
resultList = new([][]string) resultList = new([][]string)
content = map[string]interface{}{ content = M{
"Fields": []string{"Name", "Message", "Status"}, "Fields": []string{"Name", "Message", "Status"},
} }
) )
@ -332,7 +343,7 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
} }
// List Tasks // List Tasks
content := make(map[string]interface{}) content := make(M)
resultList := new([][]string) resultList := new([][]string)
var fields = []string{ var fields = []string{
"Task Name", "Task Name",

View File

@ -6,7 +6,7 @@ import (
) )
func TestList_01(t *testing.T) { func TestList_01(t *testing.T) {
m := make(map[string]interface{}) m := make(M)
list("BConfig", BConfig, m) list("BConfig", BConfig, m)
t.Log(m) t.Log(m)
om := oldMap() om := oldMap()
@ -18,8 +18,8 @@ func TestList_01(t *testing.T) {
} }
} }
func oldMap() map[string]interface{} { func oldMap() M {
m := make(map[string]interface{}) m := make(M)
m["BConfig.AppName"] = BConfig.AppName m["BConfig.AppName"] = BConfig.AppName
m["BConfig.RunMode"] = BConfig.RunMode m["BConfig.RunMode"] = BConfig.RunMode
m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
@ -67,6 +67,7 @@ func oldMap() map[string]interface{} {
m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly
m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs
m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat
m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
m["BConfig.Log.Outputs"] = BConfig.Log.Outputs m["BConfig.Log.Outputs"] = BConfig.Log.Outputs

151
app.go
View File

@ -15,17 +15,22 @@
package beego package beego
import ( import (
"crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/fcgi" "net/http/fcgi"
"os" "os"
"path" "path"
"strings"
"time" "time"
"github.com/astaxie/beego/grace" "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
"golang.org/x/crypto/acme/autocert"
) )
var ( var (
@ -51,8 +56,11 @@ func NewApp() *App {
return app return app
} }
// MiddleWare function for http.Handler
type MiddleWare func(http.Handler) http.Handler
// Run beego application. // Run beego application.
func (app *App) Run() { func (app *App) Run(mws ...MiddleWare) {
addr := BConfig.Listen.HTTPAddr addr := BConfig.Listen.HTTPAddr
if BConfig.Listen.HTTPPort != 0 { if BConfig.Listen.HTTPPort != 0 {
@ -94,6 +102,12 @@ func (app *App) Run() {
} }
app.Server.Handler = app.Handlers app.Server.Handler = app.Handlers
for i := len(mws) - 1; i >= 0; i-- {
if mws[i] == nil {
continue
}
app.Server.Handler = mws[i](app.Server.Handler)
}
app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.ErrorLog = logs.GetLogger("HTTP") app.Server.ErrorLog = logs.GetLogger("HTTP")
@ -102,9 +116,9 @@ func (app *App) Run() {
if BConfig.Listen.Graceful { if BConfig.Listen.Graceful {
httpsAddr := BConfig.Listen.HTTPSAddr httpsAddr := BConfig.Listen.HTTPSAddr
app.Server.Addr = httpsAddr app.Server.Addr = httpsAddr
if BConfig.Listen.EnableHTTPS { if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS {
go func() { go func() {
time.Sleep(20 * time.Microsecond) time.Sleep(1000 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 { if BConfig.Listen.HTTPSPort != 0 {
httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
app.Server.Addr = httpsAddr app.Server.Addr = httpsAddr
@ -112,10 +126,27 @@ func (app *App) Run() {
server := grace.NewServer(httpsAddr, app.Handlers) server := grace.NewServer(httpsAddr, app.Handlers)
server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout server.Server.WriteTimeout = app.Server.WriteTimeout
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { if BConfig.Listen.EnableMutualHTTPS {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil {
time.Sleep(100 * time.Microsecond) logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
endRunning <- true time.Sleep(100 * time.Microsecond)
endRunning <- true
}
} else {
if BConfig.Listen.AutoTLS {
m := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...),
Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir),
}
app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate}
BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", ""
}
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
} }
}() }()
} }
@ -139,22 +170,44 @@ func (app *App) Run() {
} }
// run normal mode // run normal mode
if BConfig.Listen.EnableHTTPS { if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS {
go func() { go func() {
time.Sleep(20 * time.Microsecond) time.Sleep(1000 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 { if BConfig.Listen.HTTPSPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort)
} else if BConfig.Listen.EnableHTTP { } else if BConfig.Listen.EnableHTTP {
BeeLogger.Info("Start https server error, confict with http.Please reset https port") BeeLogger.Info("Start https server error, conflict with http. Please reset https port")
return return
} }
logs.Info("https server Running on https://%s", app.Server.Addr) logs.Info("https server Running on https://%s", app.Server.Addr)
if BConfig.Listen.AutoTLS {
m := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...),
Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir),
}
app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate}
BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", ""
} else if BConfig.Listen.EnableMutualHTTPS {
pool := x509.NewCertPool()
data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile)
if err != nil {
BeeLogger.Info("MutualHTTPS should provide TrustCaFile")
return
}
pool.AppendCertsFromPEM(data)
app.Server.TLSConfig = &tls.Config{
ClientCAs: pool,
ClientAuth: tls.RequireAndVerifyClientCert,
}
}
if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err) logs.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
} }
}() }()
} }
if BConfig.Listen.EnableHTTP { if BConfig.Listen.EnableHTTP {
go func() { go func() {
@ -207,6 +260,84 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *A
return BeeApp return BeeApp
} }
// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful
// in web applications that inherit most routes from a base webapp via the underscore
// import, and aim to overwrite only certain paths.
// The method parameter can be empty or "*" for all HTTP methods, or a particular
// method type (e.g. "GET" or "POST") for selective removal.
//
// Usage (replace "GET" with "*" for all methods):
// beego.UnregisterFixedRoute("/yourpreviouspath", "GET")
// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage")
func UnregisterFixedRoute(fixedRoute string, method string) *App {
subPaths := splitPath(fixedRoute)
if method == "" || method == "*" {
for m := range HTTPMETHOD {
if _, ok := BeeApp.Handlers.routers[m]; !ok {
continue
}
if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") {
findAndRemoveSingleTree(BeeApp.Handlers.routers[m])
continue
}
findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m)
}
return BeeApp
}
// Single HTTP method
um := strings.ToUpper(method)
if _, ok := BeeApp.Handlers.routers[um]; ok {
if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") {
findAndRemoveSingleTree(BeeApp.Handlers.routers[um])
return BeeApp
}
findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um)
}
return BeeApp
}
func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) {
for i := range entryPointTree.fixrouters {
if entryPointTree.fixrouters[i].prefix == paths[0] {
if len(paths) == 1 {
if len(entryPointTree.fixrouters[i].fixrouters) > 0 {
// If the route had children subtrees, remove just the functional leaf,
// to allow children to function as before
if len(entryPointTree.fixrouters[i].leaves) > 0 {
entryPointTree.fixrouters[i].leaves[0] = nil
entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:]
}
} else {
// Remove the *Tree from the fixrouters slice
entryPointTree.fixrouters[i] = nil
if i == len(entryPointTree.fixrouters)-1 {
entryPointTree.fixrouters = entryPointTree.fixrouters[:i]
} else {
entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...)
}
}
return
}
findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method)
}
}
}
func findAndRemoveSingleTree(entryPointTree *Tree) {
if entryPointTree == nil {
return
}
if len(entryPointTree.fixrouters) > 0 {
// If the route had children subtrees, remove just the functional leaf,
// to allow children to function as before
if len(entryPointTree.leaves) > 0 {
entryPointTree.leaves[0] = nil
entryPointTree.leaves = entryPointTree.leaves[1:]
}
}
}
// Include will generate router file in the router/xxx.go from the controller's comments // Include will generate router file in the router/xxx.go from the controller's comments
// usage: // usage:
// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) // beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})

View File

@ -23,7 +23,7 @@ import (
const ( const (
// VERSION represent beego web framework version. // VERSION represent beego web framework version.
VERSION = "1.9.0" VERSION = "1.11.1"
// DEV is for develop // DEV is for develop
DEV = "dev" DEV = "dev"
@ -31,7 +31,10 @@ const (
PROD = "prod" PROD = "prod"
) )
//hook function to run // M is Map shortcut
type M map[string]interface{}
// Hook function to run
type hookfunc func() error type hookfunc func() error
var ( var (
@ -62,11 +65,29 @@ func Run(params ...string) {
if len(strs) > 1 && strs[1] != "" { if len(strs) > 1 && strs[1] != "" {
BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1])
} }
BConfig.Listen.Domains = params
} }
BeeApp.Run() BeeApp.Run()
} }
// RunWithMiddleWares Run beego application with middlewares.
func RunWithMiddleWares(addr string, mws ...MiddleWare) {
initBeforeHTTPRun()
strs := strings.Split(addr, ":")
if len(strs) > 0 && strs[0] != "" {
BConfig.Listen.HTTPAddr = strs[0]
BConfig.Listen.Domains = []string{strs[0]}
}
if len(strs) > 1 && strs[1] != "" {
BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1])
}
BeeApp.Run(mws...)
}
func initBeforeHTTPRun() { func initBeforeHTTPRun() {
//init hooks //init hooks
AddAPPStartHook( AddAPPStartHook(

2
cache/README.md vendored
View File

@ -52,7 +52,7 @@ Configure like this:
## Redis adapter ## Redis adapter
Redis adapter use the [redigo](http://github.com/garyburd/redigo) client. Redis adapter use the [redigo](http://github.com/gomodule/redigo) client.
Configure like this: Configure like this:

2
cache/cache.go vendored
View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package cache provide a Cache interface and some implemetn engine // Package cache provide a Cache interface and some implement engine
// Usage: // Usage:
// //
// import( // import(

2
cache/file.go vendored
View File

@ -65,7 +65,7 @@ func NewFileCache() Cache {
// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0} // the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}
func (fc *FileCache) StartAndGC(config string) error { func (fc *FileCache) StartAndGC(config string) error {
var cfg map[string]string cfg := make(map[string]string)
json.Unmarshal([]byte(config), &cfg) json.Unmarshal([]byte(config), &cfg)
if _, ok := cfg["CachePath"]; !ok { if _, ok := cfg["CachePath"]; !ok {
cfg["CachePath"] = FileCachePath cfg["CachePath"] = FileCachePath

View File

@ -146,7 +146,7 @@ func (rc *Cache) IsExist(key string) bool {
} }
} }
_, err := rc.conn.Get(key) _, err := rc.conn.Get(key)
return !(err != nil) return err == nil
} }
// ClearAll clear all cached in memcache. // ClearAll clear all cached in memcache.

44
cache/memory.go vendored
View File

@ -116,19 +116,19 @@ func (bc *MemoryCache) Incr(key string) error {
if !ok { if !ok {
return errors.New("key not exist") return errors.New("key not exist")
} }
switch itm.val.(type) { switch val := itm.val.(type) {
case int: case int:
itm.val = itm.val.(int) + 1 itm.val = val + 1
case int32: case int32:
itm.val = itm.val.(int32) + 1 itm.val = val + 1
case int64: case int64:
itm.val = itm.val.(int64) + 1 itm.val = val + 1
case uint: case uint:
itm.val = itm.val.(uint) + 1 itm.val = val + 1
case uint32: case uint32:
itm.val = itm.val.(uint32) + 1 itm.val = val + 1
case uint64: case uint64:
itm.val = itm.val.(uint64) + 1 itm.val = val + 1
default: default:
return errors.New("item val is not (u)int (u)int32 (u)int64") return errors.New("item val is not (u)int (u)int32 (u)int64")
} }
@ -143,28 +143,28 @@ func (bc *MemoryCache) Decr(key string) error {
if !ok { if !ok {
return errors.New("key not exist") return errors.New("key not exist")
} }
switch itm.val.(type) { switch val := itm.val.(type) {
case int: case int:
itm.val = itm.val.(int) - 1 itm.val = val - 1
case int64: case int64:
itm.val = itm.val.(int64) - 1 itm.val = val - 1
case int32: case int32:
itm.val = itm.val.(int32) - 1 itm.val = val - 1
case uint: case uint:
if itm.val.(uint) > 0 { if val > 0 {
itm.val = itm.val.(uint) - 1 itm.val = val - 1
} else { } else {
return errors.New("item val is less than 0") return errors.New("item val is less than 0")
} }
case uint32: case uint32:
if itm.val.(uint32) > 0 { if val > 0 {
itm.val = itm.val.(uint32) - 1 itm.val = val - 1
} else { } else {
return errors.New("item val is less than 0") return errors.New("item val is less than 0")
} }
case uint64: case uint64:
if itm.val.(uint64) > 0 { if val > 0 {
itm.val = itm.val.(uint64) - 1 itm.val = val - 1
} else { } else {
return errors.New("item val is less than 0") return errors.New("item val is less than 0")
} }
@ -203,13 +203,17 @@ func (bc *MemoryCache) StartAndGC(config string) error {
dur := time.Duration(cf["interval"]) * time.Second dur := time.Duration(cf["interval"]) * time.Second
bc.Every = cf["interval"] bc.Every = cf["interval"]
bc.dur = dur bc.dur = dur
go bc.vaccuum() go bc.vacuum()
return nil return nil
} }
// check expiration. // check expiration.
func (bc *MemoryCache) vaccuum() { func (bc *MemoryCache) vacuum() {
if bc.Every < 1 { bc.RLock()
every := bc.Every
bc.RUnlock()
if every < 1 {
return return
} }
for { for {

91
cache/redis/redis.go vendored
View File

@ -14,9 +14,9 @@
// Package redis for cache provider // Package redis for cache provider
// //
// depend on github.com/garyburd/redigo/redis // depend on github.com/gomodule/redigo/redis
// //
// go install github.com/garyburd/redigo/redis // go install github.com/gomodule/redigo/redis
// //
// Usage: // Usage:
// import( // import(
@ -32,12 +32,14 @@ package redis
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"strconv" "strconv"
"time" "time"
"github.com/garyburd/redigo/redis" "github.com/gomodule/redigo/redis"
"github.com/astaxie/beego/cache" "github.com/astaxie/beego/cache"
"strings"
) )
var ( var (
@ -52,6 +54,7 @@ type Cache struct {
dbNum int dbNum int
key string key string
password string password string
maxIdle int
} }
// NewRedisCache create new redis cache with default collection name. // NewRedisCache create new redis cache with default collection name.
@ -59,14 +62,23 @@ func NewRedisCache() cache.Cache {
return &Cache{key: DefaultKey} return &Cache{key: DefaultKey}
} }
// actually do the redis cmds // actually do the redis cmds, args[0] must be the key name.
func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
if len(args) < 1 {
return nil, errors.New("missing required arguments")
}
args[0] = rc.associate(args[0])
c := rc.p.Get() c := rc.p.Get()
defer c.Close() defer c.Close()
return c.Do(commandName, args...) return c.Do(commandName, args...)
} }
// associate with config key.
func (rc *Cache) associate(originKey interface{}) string {
return fmt.Sprintf("%s:%s", rc.key, originKey)
}
// Get cache from redis. // Get cache from redis.
func (rc *Cache) Get(key string) interface{} { func (rc *Cache) Get(key string) interface{} {
if v, err := rc.do("GET", key); err == nil { if v, err := rc.do("GET", key); err == nil {
@ -77,57 +89,28 @@ func (rc *Cache) Get(key string) interface{} {
// GetMulti get cache from redis. // GetMulti get cache from redis.
func (rc *Cache) GetMulti(keys []string) []interface{} { func (rc *Cache) GetMulti(keys []string) []interface{} {
size := len(keys)
var rv []interface{}
c := rc.p.Get() c := rc.p.Get()
defer c.Close() defer c.Close()
var err error var args []interface{}
for _, key := range keys { for _, key := range keys {
err = c.Send("GET", key) args = append(args, rc.associate(key))
if err != nil {
goto ERROR
}
} }
if err = c.Flush(); err != nil { values, err := redis.Values(c.Do("MGET", args...))
goto ERROR if err != nil {
return nil
} }
for i := 0; i < size; i++ { return values
if v, err := c.Receive(); err == nil {
rv = append(rv, v.([]byte))
} else {
rv = append(rv, err)
}
}
return rv
ERROR:
rv = rv[0:0]
for i := 0; i < size; i++ {
rv = append(rv, nil)
}
return rv
} }
// Put put cache to redis. // Put put cache to redis.
func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
var err error _, err := rc.do("SETEX", key, int64(timeout/time.Second), val)
if _, err = rc.do("SETEX", key, int64(timeout/time.Second), val); err != nil {
return err
}
if _, err = rc.do("HSET", rc.key, key, true); err != nil {
return err
}
return err return err
} }
// Delete delete cache in redis. // Delete delete cache in redis.
func (rc *Cache) Delete(key string) error { func (rc *Cache) Delete(key string) error {
var err error _, err := rc.do("DEL", key)
if _, err = rc.do("DEL", key); err != nil {
return err
}
_, err = rc.do("HDEL", rc.key, key)
return err return err
} }
@ -137,11 +120,6 @@ func (rc *Cache) IsExist(key string) bool {
if err != nil { if err != nil {
return false return false
} }
if !v {
if _, err = rc.do("HDEL", rc.key, key); err != nil {
return false
}
}
return v return v
} }
@ -159,16 +137,17 @@ func (rc *Cache) Decr(key string) error {
// ClearAll clean all cache in redis. delete this redis collection. // ClearAll clean all cache in redis. delete this redis collection.
func (rc *Cache) ClearAll() error { func (rc *Cache) ClearAll() error {
cachedKeys, err := redis.Strings(rc.do("HKEYS", rc.key)) c := rc.p.Get()
defer c.Close()
cachedKeys, err := redis.Strings(c.Do("KEYS", rc.key+":*"))
if err != nil { if err != nil {
return err return err
} }
for _, str := range cachedKeys { for _, str := range cachedKeys {
if _, err = rc.do("DEL", str); err != nil { if _, err = c.Do("DEL", str); err != nil {
return err return err
} }
} }
_, err = rc.do("DEL", rc.key)
return err return err
} }
@ -186,16 +165,28 @@ func (rc *Cache) StartAndGC(config string) error {
if _, ok := cf["conn"]; !ok { if _, ok := cf["conn"]; !ok {
return errors.New("config has no conn key") return errors.New("config has no conn key")
} }
// Format redis://<password>@<host>:<port>
cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1)
if i := strings.Index(cf["conn"], "@"); i > -1 {
cf["password"] = cf["conn"][0:i]
cf["conn"] = cf["conn"][i+1:]
}
if _, ok := cf["dbNum"]; !ok { if _, ok := cf["dbNum"]; !ok {
cf["dbNum"] = "0" cf["dbNum"] = "0"
} }
if _, ok := cf["password"]; !ok { if _, ok := cf["password"]; !ok {
cf["password"] = "" cf["password"] = ""
} }
if _, ok := cf["maxIdle"]; !ok {
cf["maxIdle"] = "3"
}
rc.key = cf["key"] rc.key = cf["key"]
rc.conninfo = cf["conn"] rc.conninfo = cf["conn"]
rc.dbNum, _ = strconv.Atoi(cf["dbNum"]) rc.dbNum, _ = strconv.Atoi(cf["dbNum"])
rc.password = cf["password"] rc.password = cf["password"]
rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"])
rc.connectInit() rc.connectInit()
@ -229,7 +220,7 @@ func (rc *Cache) connectInit() {
} }
// initialize a new pool // initialize a new pool
rc.p = &redis.Pool{ rc.p = &redis.Pool{
MaxIdle: 3, MaxIdle: rc.maxIdle,
IdleTimeout: 180 * time.Second, IdleTimeout: 180 * time.Second,
Dial: dialFunc, Dial: dialFunc,
} }

View File

@ -19,7 +19,7 @@ import (
"time" "time"
"github.com/astaxie/beego/cache" "github.com/astaxie/beego/cache"
"github.com/garyburd/redigo/redis" "github.com/gomodule/redigo/redis"
) )
func TestRedisCache(t *testing.T) { func TestRedisCache(t *testing.T) {

View File

@ -49,22 +49,27 @@ type Config struct {
// Listen holds for http and https related config // Listen holds for http and https related config
type Listen struct { type Listen struct {
Graceful bool // Graceful means use graceful module to start the server Graceful bool // Graceful means use graceful module to start the server
ServerTimeOut int64 ServerTimeOut int64
ListenTCP4 bool ListenTCP4 bool
EnableHTTP bool EnableHTTP bool
HTTPAddr string HTTPAddr string
HTTPPort int HTTPPort int
EnableHTTPS bool AutoTLS bool
HTTPSAddr string Domains []string
HTTPSPort int TLSCacheDir string
HTTPSCertFile string EnableHTTPS bool
HTTPSKeyFile string EnableMutualHTTPS bool
EnableAdmin bool HTTPSAddr string
AdminAddr string HTTPSPort int
AdminPort int HTTPSCertFile string
EnableFcgi bool HTTPSKeyFile string
EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O TrustCaFile string
EnableAdmin bool
AdminAddr string
AdminPort int
EnableFcgi bool
EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O
} }
// WebConfig holds web related config // WebConfig holds web related config
@ -96,17 +101,18 @@ type SessionConfig struct {
SessionAutoSetCookie bool SessionAutoSetCookie bool
SessionDomain string SessionDomain string
SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies.
SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers
SessionNameInHTTPHeader string SessionNameInHTTPHeader string
SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params
} }
// LogConfig holds Log related config // LogConfig holds Log related config
type LogConfig struct { type LogConfig struct {
AccessLogs bool AccessLogs bool
AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string EnableStaticLogs bool //log static files requests default: false
FileLineNum bool AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string
Outputs map[string]string // Store Adaptor : config FileLineNum bool
Outputs map[string]string // Store Adaptor : config
} }
var ( var (
@ -135,9 +141,13 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
appConfigPath = filepath.Join(workPath, "conf", "app.conf") var filename = "app.conf"
if os.Getenv("BEEGO_RUNMODE") != "" {
filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf"
}
appConfigPath = filepath.Join(workPath, "conf", filename)
if !utils.FileExists(appConfigPath) { if !utils.FileExists(appConfigPath) {
appConfigPath = filepath.Join(AppPath, "conf", "app.conf") appConfigPath = filepath.Join(AppPath, "conf", filename)
if !utils.FileExists(appConfigPath) { if !utils.FileExists(appConfigPath) {
AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()}
return return
@ -176,13 +186,18 @@ func recoverPanic(ctx *context.Context) {
if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { if BConfig.RunMode == DEV && BConfig.EnableErrorsRender {
showErr(err, ctx, stack) showErr(err, ctx, stack)
} }
if ctx.Output.Status != 0 {
ctx.ResponseWriter.WriteHeader(ctx.Output.Status)
} else {
ctx.ResponseWriter.WriteHeader(500)
}
} }
} }
func newBConfig() *Config { func newBConfig() *Config {
return &Config{ return &Config{
AppName: "beego", AppName: "beego",
RunMode: DEV, RunMode: PROD,
RouterCaseSensitive: true, RouterCaseSensitive: true,
ServerName: "beegoServer:" + VERSION, ServerName: "beegoServer:" + VERSION,
RecoverPanic: true, RecoverPanic: true,
@ -197,6 +212,9 @@ func newBConfig() *Config {
ServerTimeOut: 0, ServerTimeOut: 0,
ListenTCP4: false, ListenTCP4: false,
EnableHTTP: true, EnableHTTP: true,
AutoTLS: false,
Domains: []string{},
TLSCacheDir: ".",
HTTPAddr: "", HTTPAddr: "",
HTTPPort: 8080, HTTPPort: 8080,
EnableHTTPS: false, EnableHTTPS: false,
@ -234,16 +252,17 @@ func newBConfig() *Config {
SessionCookieLifeTime: 0, //set cookie default is the browser life SessionCookieLifeTime: 0, //set cookie default is the browser life
SessionAutoSetCookie: true, SessionAutoSetCookie: true,
SessionDomain: "", SessionDomain: "",
SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers
SessionNameInHTTPHeader: "Beegosessionid", SessionNameInHTTPHeader: "Beegosessionid",
SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params
}, },
}, },
Log: LogConfig{ Log: LogConfig{
AccessLogs: false, AccessLogs: false,
AccessLogsFormat: "", EnableStaticLogs: false,
FileLineNum: true, AccessLogsFormat: "APACHE_FORMAT",
Outputs: map[string]string{"console": ""}, FileLineNum: true,
Outputs: map[string]string{"console": ""},
}, },
} }
} }

View File

@ -150,12 +150,12 @@ func ExpandValueEnv(value string) (realValue string) {
} }
key := "" key := ""
defalutV := "" defaultV := ""
// value start with "${" // value start with "${"
for i := 2; i < vLen; i++ { for i := 2; i < vLen; i++ {
if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') { if value[i] == '|' && (i+1 < vLen && value[i+1] == '|') {
key = value[2:i] key = value[2:i]
defalutV = value[i+2 : vLen-1] // other string is default value. defaultV = value[i+2 : vLen-1] // other string is default value.
break break
} else if value[i] == '}' { } else if value[i] == '}' {
key = value[2:i] key = value[2:i]
@ -165,7 +165,7 @@ func ExpandValueEnv(value string) (realValue string) {
realValue = os.Getenv(key) realValue = os.Getenv(key)
if realValue == "" { if realValue == "" {
realValue = defalutV realValue = defaultV
} }
return return

View File

@ -126,7 +126,7 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error {
var _ Configer = new(fakeConfigContainer) var _ Configer = new(fakeConfigContainer)
// NewFakeConfig return a fake Congiger // NewFakeConfig return a fake Configer
func NewFakeConfig() Configer { func NewFakeConfig() Configer {
return &fakeConfigContainer{ return &fakeConfigContainer{
data: make(map[string]string), data: make(map[string]string),

View File

@ -78,15 +78,37 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e
} }
} }
section := defaultSection section := defaultSection
tmpBuf := bytes.NewBuffer(nil)
for { for {
line, _, err := buf.ReadLine() tmpBuf.Reset()
if err == io.EOF {
shouldBreak := false
for {
tmp, isPrefix, err := buf.ReadLine()
if err == io.EOF {
shouldBreak = true
break
}
//It might be a good idea to throw a error on all unknonw errors?
if _, ok := err.(*os.PathError); ok {
return nil, err
}
tmpBuf.Write(tmp)
if isPrefix {
continue
}
if !isPrefix {
break
}
}
if shouldBreak {
break break
} }
//It might be a good idea to throw a error on all unknonw errors?
if _, ok := err.(*os.PathError); ok { line := tmpBuf.Bytes()
return nil, err
}
line = bytes.TrimSpace(line) line = bytes.TrimSpace(line)
if bytes.Equal(line, bEmpty) { if bytes.Equal(line, bEmpty) {
continue continue
@ -215,7 +237,7 @@ func (c *IniConfigContainer) Bool(key string) (bool, error) {
} }
// DefaultBool returns the boolean value for a given key. // DefaultBool returns the boolean value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool {
v, err := c.Bool(key) v, err := c.Bool(key)
if err != nil { if err != nil {
@ -230,7 +252,7 @@ func (c *IniConfigContainer) Int(key string) (int, error) {
} }
// DefaultInt returns the integer value for a given key. // DefaultInt returns the integer value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int {
v, err := c.Int(key) v, err := c.Int(key)
if err != nil { if err != nil {
@ -245,7 +267,7 @@ func (c *IniConfigContainer) Int64(key string) (int64, error) {
} }
// DefaultInt64 returns the int64 value for a given key. // DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
v, err := c.Int64(key) v, err := c.Int64(key)
if err != nil { if err != nil {
@ -260,7 +282,7 @@ func (c *IniConfigContainer) Float(key string) (float64, error) {
} }
// DefaultFloat returns the float64 value for a given key. // DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
v, err := c.Float(key) v, err := c.Float(key)
if err != nil { if err != nil {
@ -275,7 +297,7 @@ func (c *IniConfigContainer) String(key string) string {
} }
// DefaultString returns the string value for a given key. // DefaultString returns the string value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { func (c *IniConfigContainer) DefaultString(key string, defaultval string) string {
v := c.String(key) v := c.String(key)
if v == "" { if v == "" {
@ -295,7 +317,7 @@ func (c *IniConfigContainer) Strings(key string) []string {
} }
// DefaultStrings returns the []string value for a given key. // DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string {
v := c.Strings(key) v := c.Strings(key)
if v == nil { if v == nil {
@ -314,7 +336,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro
// SaveConfigFile save the config into file. // SaveConfigFile save the config into file.
// //
// BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Funcation. // BUG(env): The environment variable config item will be saved with real value in SaveConfigFile Function.
func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
// Write configuration file by filename. // Write configuration file by filename.
f, err := os.Create(filename) f, err := os.Create(filename)

View File

@ -101,7 +101,7 @@ func (c *JSONConfigContainer) Int(key string) (int, error) {
} }
// DefaultInt returns the integer value for a given key. // DefaultInt returns the integer value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int {
if v, err := c.Int(key); err == nil { if v, err := c.Int(key); err == nil {
return v return v
@ -122,7 +122,7 @@ func (c *JSONConfigContainer) Int64(key string) (int64, error) {
} }
// DefaultInt64 returns the int64 value for a given key. // DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
if v, err := c.Int64(key); err == nil { if v, err := c.Int64(key); err == nil {
return v return v
@ -143,7 +143,7 @@ func (c *JSONConfigContainer) Float(key string) (float64, error) {
} }
// DefaultFloat returns the float64 value for a given key. // DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
if v, err := c.Float(key); err == nil { if v, err := c.Float(key); err == nil {
return v return v
@ -163,7 +163,7 @@ func (c *JSONConfigContainer) String(key string) string {
} }
// DefaultString returns the string value for a given key. // DefaultString returns the string value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string {
// TODO FIXME should not use "" to replace non existence // TODO FIXME should not use "" to replace non existence
if v := c.String(key); v != "" { if v := c.String(key); v != "" {
@ -182,7 +182,7 @@ func (c *JSONConfigContainer) Strings(key string) []string {
} }
// DefaultStrings returns the []string value for a given key. // DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string {
if v := c.Strings(key); v != nil { if v := c.Strings(key); v != nil {
return v return v

View File

@ -216,7 +216,7 @@ func TestJson(t *testing.T) {
t.Error("unknown keys should return an error when expecting a Bool") t.Error("unknown keys should return an error when expecting a Bool")
} }
if !jsonconf.DefaultBool("unknow", true) { if !jsonconf.DefaultBool("unknown", true) {
t.Error("unknown keys with default value wrong") t.Error("unknown keys with default value wrong")
} }
} }

View File

@ -102,7 +102,7 @@ func (c *ConfigContainer) Int(key string) (int, error) {
} }
// DefaultInt returns the integer value for a given key. // DefaultInt returns the integer value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
v, err := c.Int(key) v, err := c.Int(key)
if err != nil { if err != nil {
@ -117,7 +117,7 @@ func (c *ConfigContainer) Int64(key string) (int64, error) {
} }
// DefaultInt64 returns the int64 value for a given key. // DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
v, err := c.Int64(key) v, err := c.Int64(key)
if err != nil { if err != nil {
@ -133,7 +133,7 @@ func (c *ConfigContainer) Float(key string) (float64, error) {
} }
// DefaultFloat returns the float64 value for a given key. // DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
v, err := c.Float(key) v, err := c.Float(key)
if err != nil { if err != nil {
@ -151,7 +151,7 @@ func (c *ConfigContainer) String(key string) string {
} }
// DefaultString returns the string value for a given key. // DefaultString returns the string value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *ConfigContainer) DefaultString(key string, defaultval string) string { func (c *ConfigContainer) DefaultString(key string, defaultval string) string {
v := c.String(key) v := c.String(key)
if v == "" { if v == "" {
@ -170,7 +170,7 @@ func (c *ConfigContainer) Strings(key string) []string {
} }
// DefaultStrings returns the []string value for a given key. // DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string {
v := c.Strings(key) v := c.Strings(key)
if v == nil { if v == nil {

View File

@ -119,7 +119,7 @@ func parseYML(buf []byte) (cnf map[string]interface{}, err error) {
// ConfigContainer A Config represents the yaml configuration. // ConfigContainer A Config represents the yaml configuration.
type ConfigContainer struct { type ConfigContainer struct {
data map[string]interface{} data map[string]interface{}
sync.Mutex sync.RWMutex
} }
// Bool returns the boolean value for a given key. // Bool returns the boolean value for a given key.
@ -154,7 +154,7 @@ func (c *ConfigContainer) Int(key string) (int, error) {
} }
// DefaultInt returns the integer value for a given key. // DefaultInt returns the integer value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { func (c *ConfigContainer) DefaultInt(key string, defaultval int) int {
v, err := c.Int(key) v, err := c.Int(key)
if err != nil { if err != nil {
@ -174,7 +174,7 @@ func (c *ConfigContainer) Int64(key string) (int64, error) {
} }
// DefaultInt64 returns the int64 value for a given key. // DefaultInt64 returns the int64 value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 {
v, err := c.Int64(key) v, err := c.Int64(key)
if err != nil { if err != nil {
@ -198,7 +198,7 @@ func (c *ConfigContainer) Float(key string) (float64, error) {
} }
// DefaultFloat returns the float64 value for a given key. // DefaultFloat returns the float64 value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 {
v, err := c.Float(key) v, err := c.Float(key)
if err != nil { if err != nil {
@ -218,7 +218,7 @@ func (c *ConfigContainer) String(key string) string {
} }
// DefaultString returns the string value for a given key. // DefaultString returns the string value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *ConfigContainer) DefaultString(key string, defaultval string) string { func (c *ConfigContainer) DefaultString(key string, defaultval string) string {
v := c.String(key) v := c.String(key)
if v == "" { if v == "" {
@ -237,7 +237,7 @@ func (c *ConfigContainer) Strings(key string) []string {
} }
// DefaultStrings returns the []string value for a given key. // DefaultStrings returns the []string value for a given key.
// if err != nil return defaltval // if err != nil return defaultval
func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string {
v := c.Strings(key) v := c.Strings(key)
if v == nil { if v == nil {
@ -285,9 +285,28 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) {
if len(key) == 0 { if len(key) == 0 {
return nil, errors.New("key is empty") return nil, errors.New("key is empty")
} }
c.RLock()
defer c.RUnlock()
if v, ok := c.data[key]; ok { keys := strings.Split(key, ".")
return v, nil tmpData := c.data
for idx, k := range keys {
if v, ok := tmpData[k]; ok {
switch v.(type) {
case map[string]interface{}:
{
tmpData = v.(map[string]interface{})
if idx == len(keys) - 1 {
return tmpData, nil
}
}
default:
{
return v, nil
}
}
}
} }
return nil, fmt.Errorf("not exist key %q", key) return nil, fmt.Errorf("not exist key %q", key)
} }

View File

@ -48,15 +48,15 @@ func TestAssignConfig_02(t *testing.T) {
_BConfig := &Config{} _BConfig := &Config{}
bs, _ := json.Marshal(newBConfig()) bs, _ := json.Marshal(newBConfig())
jsonMap := map[string]interface{}{} jsonMap := M{}
json.Unmarshal(bs, &jsonMap) json.Unmarshal(bs, &jsonMap)
configMap := map[string]interface{}{} configMap := M{}
for k, v := range jsonMap { for k, v := range jsonMap {
if reflect.TypeOf(v).Kind() == reflect.Map { if reflect.TypeOf(v).Kind() == reflect.Map {
for k1, v1 := range v.(map[string]interface{}) { for k1, v1 := range v.(M) {
if reflect.TypeOf(v1).Kind() == reflect.Map { if reflect.TypeOf(v1).Kind() == reflect.Map {
for k2, v2 := range v1.(map[string]interface{}) { for k2, v2 := range v1.(M) {
configMap[k2] = v2 configMap[k2] = v2
} }
} else { } else {
@ -75,7 +75,7 @@ func TestAssignConfig_02(t *testing.T) {
jcf := &config.JSONConfig{} jcf := &config.JSONConfig{}
bs, _ = json.Marshal(configMap) bs, _ = json.Marshal(configMap)
ac, _ := jcf.ParseData([]byte(bs)) ac, _ := jcf.ParseData(bs)
for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} { for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} {
assignSingleConfig(i, ac) assignSingleConfig(i, ac)

View File

@ -38,6 +38,14 @@ import (
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
//commonly used mime-types
const (
ApplicationJSON = "application/json"
ApplicationXML = "application/xml"
ApplicationYAML = "application/x-yaml"
TextXML = "text/xml"
)
// NewContext return the Context with Input and Output // NewContext return the Context with Input and Output
func NewContext() *Context { func NewContext() *Context {
return &Context{ return &Context{
@ -193,6 +201,7 @@ type Response struct {
http.ResponseWriter http.ResponseWriter
Started bool Started bool
Status int Status int
Elapsed time.Duration
} }
func (r *Response) reset(rw http.ResponseWriter) { func (r *Response) reset(rw http.ResponseWriter) {
@ -244,3 +253,11 @@ func (r *Response) CloseNotify() <-chan bool {
} }
return nil return nil
} }
// Pusher http.Pusher
func (r *Response) Pusher() (pusher http.Pusher) {
if pusher, ok := r.ResponseWriter.(http.Pusher); ok {
return pusher
}
return nil
}

View File

@ -37,6 +37,7 @@ var (
acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`) acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`)
acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`) acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`)
acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`) acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`)
acceptsYAMLRegex = regexp.MustCompile(`(application/x-yaml)(?:,|$)`)
maxParam = 50 maxParam = 50
) )
@ -203,6 +204,10 @@ func (input *BeegoInput) AcceptsXML() bool {
func (input *BeegoInput) AcceptsJSON() bool { func (input *BeegoInput) AcceptsJSON() bool {
return acceptsJSONRegex.MatchString(input.Header("Accept")) return acceptsJSONRegex.MatchString(input.Header("Accept"))
} }
// AcceptsYAML Checks if request accepts json response
func (input *BeegoInput) AcceptsYAML() bool {
return acceptsYAMLRegex.MatchString(input.Header("Accept"))
}
// IP returns request client ip. // IP returns request client ip.
// if in proxy, return first proxy id. // if in proxy, return first proxy id.

View File

@ -30,6 +30,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
yaml "gopkg.in/yaml.v2"
) )
// BeegoOutput does work for sending response header. // BeegoOutput does work for sending response header.
@ -182,8 +184,8 @@ func errorRenderer(err error) Renderer {
} }
// JSON writes json to response body. // JSON writes json to response body.
// if coding is true, it converts utf-8 to \u0000 type. // if encoding is true, it converts utf-8 to \u0000 type.
func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error { func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error {
output.Header("Content-Type", "application/json; charset=utf-8") output.Header("Content-Type", "application/json; charset=utf-8")
var content []byte var content []byte
var err error var err error
@ -196,12 +198,25 @@ func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) e
http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError)
return err return err
} }
if coding { if encoding {
content = []byte(stringsToJSON(string(content))) content = []byte(stringsToJSON(string(content)))
} }
return output.Body(content) return output.Body(content)
} }
// YAML writes yaml to response body.
func (output *BeegoOutput) YAML(data interface{}) error {
output.Header("Content-Type", "application/x-yaml; charset=utf-8")
var content []byte
var err error
content, err = yaml.Marshal(data)
if err != nil {
http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError)
return err
}
return output.Body(content)
}
// JSONP writes jsonp to response body. // JSONP writes jsonp to response body.
func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/javascript; charset=utf-8") output.Header("Content-Type", "application/javascript; charset=utf-8")
@ -245,6 +260,19 @@ func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error {
return output.Body(content) return output.Body(content)
} }
// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header
func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) {
accept := output.Context.Input.Header("Accept")
switch accept {
case ApplicationYAML:
output.YAML(data)
case ApplicationXML, TextXML:
output.XML(data, hasIndent)
default:
output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0])
}
}
// Download forces response for download file. // Download forces response for download file.
// it prepares the download response header automatically. // it prepares the download response header automatically.
func (output *BeegoOutput) Download(file string, filename ...string) { func (output *BeegoOutput) Download(file string, filename ...string) {
@ -260,7 +288,20 @@ func (output *BeegoOutput) Download(file string, filename ...string) {
} else { } else {
fName = filepath.Base(file) fName = filepath.Base(file)
} }
output.Header("Content-Disposition", "attachment; filename="+url.QueryEscape(fName)) //https://tools.ietf.org/html/rfc6266#section-4.3
fn := url.PathEscape(fName)
if fName == fn {
fn = "filename=" + fn
} else {
/**
The parameters "filename" and "filename*" differ only in that
"filename*" uses the encoding defined in [RFC5987], allowing the use
of characters not present in the ISO-8859-1 character set
([ISO-8859-1]).
*/
fn = "filename=" + fName + "; filename*=utf-8''" + fn
}
output.Header("Content-Disposition", "attachment; "+fn)
output.Header("Content-Description", "File Transfer") output.Header("Content-Description", "File Transfer")
output.Header("Content-Type", "application/octet-stream") output.Header("Content-Type", "application/octet-stream")
output.Header("Content-Transfer-Encoding", "binary") output.Header("Content-Transfer-Encoding", "binary")
@ -325,13 +366,13 @@ func (output *BeegoOutput) IsForbidden() bool {
} }
// IsNotFound returns boolean of this request is not found. // IsNotFound returns boolean of this request is not found.
// HTTP 404 means forbidden. // HTTP 404 means not found.
func (output *BeegoOutput) IsNotFound() bool { func (output *BeegoOutput) IsNotFound() bool {
return output.Status == 404 return output.Status == 404
} }
// IsClientError returns boolean of this request client sends error data. // IsClientError returns boolean of this request client sends error data.
// HTTP 4xx means forbidden. // HTTP 4xx means client error.
func (output *BeegoOutput) IsClientError() bool { func (output *BeegoOutput) IsClientError() bool {
return output.Status >= 400 && output.Status < 500 return output.Status >= 400 && output.Status < 500
} }
@ -350,6 +391,11 @@ func stringsToJSON(str string) string {
jsons.WriteRune(r) jsons.WriteRune(r)
} else { } else {
jsons.WriteString("\\u") jsons.WriteString("\\u")
if rint < 0x100 {
jsons.WriteString("00")
} else if rint < 0x1000 {
jsons.WriteString("0")
}
jsons.WriteString(strconv.FormatInt(int64(rint), 16)) jsons.WriteString(strconv.FormatInt(int64(rint), 16))
} }
} }

View File

@ -7,7 +7,7 @@ import (
// MethodParamOption defines a func which apply options on a MethodParam // MethodParamOption defines a func which apply options on a MethodParam
type MethodParamOption func(*MethodParam) type MethodParamOption func(*MethodParam)
// IsRequired indicates that this param is required and can not be ommited from the http request // IsRequired indicates that this param is required and can not be omitted from the http request
var IsRequired MethodParamOption = func(p *MethodParam) { var IsRequired MethodParamOption = func(p *MethodParam) {
p.required = true p.required = true
} }

View File

@ -75,7 +75,7 @@ func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) {
} }
convResult, err := safeConvert(reflect.ValueOf(result), toType) convResult, err := safeConvert(reflect.ValueOf(result), toType)
if err != nil { if err != nil {
t.Errorf("Convertion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err) t.Errorf("Conversion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err)
return return
} }
if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) { if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) {

View File

@ -32,24 +32,44 @@ import (
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
) )
//commonly used mime-types
const (
applicationJSON = "application/json"
applicationXML = "application/xml"
textXML = "text/xml"
)
var ( var (
// ErrAbort custom error when user stop request handler manually. // ErrAbort custom error when user stop request handler manually.
ErrAbort = errors.New("User stop run") ErrAbort = errors.New("user stop run")
// GlobalControllerRouter store comments with controller. pkgpath+controller:comments // GlobalControllerRouter store comments with controller. pkgpath+controller:comments
GlobalControllerRouter = make(map[string][]ControllerComments) GlobalControllerRouter = make(map[string][]ControllerComments)
) )
// ControllerFilter store the filter for controller
type ControllerFilter struct {
Pattern string
Pos int
Filter FilterFunc
ReturnOnOutput bool
ResetParams bool
}
// ControllerFilterComments store the comment for controller level filter
type ControllerFilterComments struct {
Pattern string
Pos int
Filter string // NOQA
ReturnOnOutput bool
ResetParams bool
}
// ControllerImportComments store the import comment for controller needed
type ControllerImportComments struct {
ImportPath string
ImportAlias string
}
// ControllerComments store the comment for the controller method // ControllerComments store the comment for the controller method
type ControllerComments struct { type ControllerComments struct {
Method string Method string
Router string Router string
Filters []*ControllerFilter
ImportComments []*ControllerImportComments
FilterComments []*ControllerFilterComments
AllowHTTPMethods []string AllowHTTPMethods []string
Params []map[string]string Params []map[string]string
MethodParams []*param.MethodParam MethodParams []*param.MethodParam
@ -73,7 +93,6 @@ type Controller struct {
controllerName string controllerName string
actionName string actionName string
methodMapping map[string]func() //method:routertree methodMapping map[string]func() //method:routertree
gotofunc string
AppController interface{} AppController interface{}
// template data // template data
@ -136,37 +155,37 @@ func (c *Controller) Finish() {}
// Get adds a request function to handle GET request. // Get adds a request function to handle GET request.
func (c *Controller) Get() { func (c *Controller) Get() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Post adds a request function to handle POST request. // Post adds a request function to handle POST request.
func (c *Controller) Post() { func (c *Controller) Post() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Delete adds a request function to handle DELETE request. // Delete adds a request function to handle DELETE request.
func (c *Controller) Delete() { func (c *Controller) Delete() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Put adds a request function to handle PUT request. // Put adds a request function to handle PUT request.
func (c *Controller) Put() { func (c *Controller) Put() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Head adds a request function to handle HEAD request. // Head adds a request function to handle HEAD request.
func (c *Controller) Head() { func (c *Controller) Head() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Patch adds a request function to handle PATCH request. // Patch adds a request function to handle PATCH request.
func (c *Controller) Patch() { func (c *Controller) Patch() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// Options adds a request function to handle OPTIONS request. // Options adds a request function to handle OPTIONS request.
func (c *Controller) Options() { func (c *Controller) Options() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
} }
// HandlerFunc call function with the name // HandlerFunc call function with the name
@ -272,9 +291,23 @@ func (c *Controller) viewPath() string {
// Redirect sends the redirection response to url with status code. // Redirect sends the redirection response to url with status code.
func (c *Controller) Redirect(url string, code int) { func (c *Controller) Redirect(url string, code int) {
logAccess(c.Ctx, nil, code)
c.Ctx.Redirect(code, url) c.Ctx.Redirect(code, url)
} }
// SetData set the data depending on the accepted
func (c *Controller) SetData(data interface{}) {
accept := c.Ctx.Input.Header("Accept")
switch accept {
case context.ApplicationYAML:
c.Data["yaml"] = data
case context.ApplicationXML, context.TextXML:
c.Data["xml"] = data
default:
c.Data["json"] = data
}
}
// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string. // Abort stops controller handler and show the error data if code is defined in ErrorMap or code string.
func (c *Controller) Abort(code string) { func (c *Controller) Abort(code string) {
status, err := strconv.Atoi(code) status, err := strconv.Atoi(code)
@ -317,47 +350,35 @@ func (c *Controller) URLFor(endpoint string, values ...interface{}) string {
// ServeJSON sends a json response with encoding charset. // ServeJSON sends a json response with encoding charset.
func (c *Controller) ServeJSON(encoding ...bool) { func (c *Controller) ServeJSON(encoding ...bool) {
var ( var (
hasIndent = true hasIndent = BConfig.RunMode != PROD
hasEncoding = false hasEncoding = len(encoding) > 0 && encoding[0]
) )
if BConfig.RunMode == PROD {
hasIndent = false
}
if len(encoding) > 0 && encoding[0] {
hasEncoding = true
}
c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding)
} }
// ServeJSONP sends a jsonp response. // ServeJSONP sends a jsonp response.
func (c *Controller) ServeJSONP() { func (c *Controller) ServeJSONP() {
hasIndent := true hasIndent := BConfig.RunMode != PROD
if BConfig.RunMode == PROD {
hasIndent = false
}
c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent)
} }
// ServeXML sends xml response. // ServeXML sends xml response.
func (c *Controller) ServeXML() { func (c *Controller) ServeXML() {
hasIndent := true hasIndent := BConfig.RunMode != PROD
if BConfig.RunMode == PROD {
hasIndent = false
}
c.Ctx.Output.XML(c.Data["xml"], hasIndent) c.Ctx.Output.XML(c.Data["xml"], hasIndent)
} }
// ServeFormatted serve Xml OR Json, depending on the value of the Accept header // ServeYAML sends yaml response.
func (c *Controller) ServeFormatted() { func (c *Controller) ServeYAML() {
accept := c.Ctx.Input.Header("Accept") c.Ctx.Output.YAML(c.Data["yaml"])
switch accept { }
case applicationJSON:
c.ServeJSON() // ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header
case applicationXML, textXML: func (c *Controller) ServeFormatted(encoding ...bool) {
c.ServeXML() hasIndent := BConfig.RunMode != PROD
default: hasEncoding := len(encoding) > 0 && encoding[0]
c.ServeJSON() c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding)
}
} }
// Input returns the input data map from POST or PUT request body and query string. // Input returns the input data map from POST or PUT request body and query string.

View File

@ -28,7 +28,7 @@ import (
) )
const ( const (
errorTypeHandler = iota errorTypeHandler = iota
errorTypeController errorTypeController
) )
@ -93,11 +93,6 @@ func showErr(err interface{}, ctx *context.Context, stack string) {
"BeegoVersion": VERSION, "BeegoVersion": VERSION,
"GoVersion": runtime.Version(), "GoVersion": runtime.Version(),
} }
if ctx.Output.Status != 0 {
ctx.ResponseWriter.WriteHeader(ctx.Output.Status)
} else {
ctx.ResponseWriter.WriteHeader(500)
}
t.Execute(ctx.ResponseWriter, data) t.Execute(ctx.ResponseWriter, data)
} }
@ -366,7 +361,7 @@ func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) {
t, _ := template.New("beegoerrortemp").Parse(errtpl) t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := map[string]interface{}{ data := M{
"Title": http.StatusText(errCode), "Title": http.StatusText(errCode),
"BeegoVersion": VERSION, "BeegoVersion": VERSION,
"Content": template.HTML(errContent), "Content": template.HTML(errContent),
@ -439,6 +434,9 @@ func exception(errCode string, ctx *context.Context) {
} }
func executeError(err *errorInfo, ctx *context.Context, code int) { func executeError(err *errorInfo, ctx *context.Context, code int) {
//make sure to log the error in the access log
logAccess(ctx, nil, code)
if err.errorType == errorTypeHandler { if err.errorType == errorTypeHandler {
ctx.ResponseWriter.WriteHeader(code) ctx.ResponseWriter.WriteHeader(code)
err.handler(ctx.ResponseWriter, ctx.Request) err.handler(ctx.ResponseWriter, ctx.Request)

74
fs.go Normal file
View File

@ -0,0 +1,74 @@
package beego
import (
"net/http"
"os"
"path/filepath"
)
type FileSystem struct {
}
func (d FileSystem) Open(name string) (http.File, error) {
return os.Open(name)
}
// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or
// directory in the tree, including root. All errors that arise visiting files
// and directories are filtered by walkFn.
func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error {
f, err := fs.Open(root)
if err != nil {
return err
}
info, err := f.Stat()
if err != nil {
err = walkFn(root, nil, err)
} else {
err = walk(fs, root, info, walkFn)
}
if err == filepath.SkipDir {
return nil
}
return err
}
// walk recursively descends path, calling walkFn.
func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error {
var err error
if !info.IsDir() {
return walkFn(path, info, nil)
}
dir, err := fs.Open(path)
if err != nil {
if err1 := walkFn(path, info, err); err1 != nil {
return err1
}
return err
}
defer dir.Close()
dirs, err := dir.Readdir(-1)
err1 := walkFn(path, info, err)
// If err != nil, walk can't walk into this directory.
// err1 != nil means walkFn want walk to skip this directory or stop walking.
// Therefore, if one of err and err1 isn't nil, walk will return.
if err != nil || err1 != nil {
// The caller's behavior is controlled by the return value, which is decided
// by walkFn. walkFn may ignore err and return nil.
// If walkFn returns SkipDir, it will be handled by the caller.
// So walk should return whatever walkFn returns.
return err1
}
for _, fileInfo := range dirs {
filename := filepath.Join(path, fileInfo.Name())
if err = walk(fs, filename, fileInfo, walkFn); err != nil {
if !fileInfo.IsDir() || err != filepath.SkipDir {
return err
}
}
}
return nil
}

39
go.mod Normal file
View File

@ -0,0 +1,39 @@
module github.com/astaxie/beego
require (
github.com/Knetic/govaluate v3.0.0+incompatible // indirect
github.com/OwnLocal/goes v1.0.0
github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd
github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737
github.com/casbin/casbin v1.7.0
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58
github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb
github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c // indirect
github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a // indirect
github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 // indirect
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 // indirect
github.com/elazarl/go-bindata-assetfs v1.0.0
github.com/go-redis/redis v6.14.2+incompatible
github.com/go-sql-driver/mysql v1.4.1
github.com/gogo/protobuf v1.1.1
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
github.com/gomodule/redigo v2.0.0+incompatible
github.com/lib/pq v1.0.0
github.com/mattn/go-sqlite3 v1.10.0
github.com/pelletier/go-toml v1.2.0 // indirect
github.com/pkg/errors v0.8.0 // indirect
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 // indirect
github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373
github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d // indirect
github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec
github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c // indirect
github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect
golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a // indirect
gopkg.in/yaml.v2 v2.2.1
)
replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85
replace gopkg.in/yaml.v2 v2.2.1 => github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d

68
go.sum Normal file
View File

@ -0,0 +1,68 @@
github.com/Knetic/govaluate v3.0.0+incompatible h1:7o6+MAPhYTCF0+fdvoz1xDedhRb4f6s9Tn1Tt7/WTEg=
github.com/Knetic/govaluate v3.0.0+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0=
github.com/OwnLocal/goes v1.0.0/go.mod h1:8rIFjBGTue3lCU0wplczcUgt9Gxgrkkrw7etMIcn8TM=
github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd h1:jZtX5jh5IOMu0fpOTC3ayh6QGSPJ/KWOv1lgPvbRw1M=
github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd/go.mod h1:1b+Y/CofkYwXMUU0OhQqGvsY2Bvgr4j6jfT699wyZKQ=
github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 h1:nYXb+3jF6Oq/j8R/y90XrKpreCxIalBWfeyeKymgOPk=
github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542/go.mod h1:kSeGC/p1AbBiEp5kat81+DSQrZenVBZXklMLaELspWU=
github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff h1:/kO0p2RTGLB8R5gub7ps0GmYpB2O8LXEoPq8tzFDCUI=
github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff/go.mod h1:PhH1ZhyCzHKt4uAasyx+ljRCgoezetRNf59CUtwUkqY=
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 h1:rRISKWyXfVxvoa702s91Zl5oREZTrR3yv+tXrrX7G/g=
github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60=
github.com/casbin/casbin v1.7.0 h1:PuzlE8w0JBg/DhIqnkF1Dewf3z+qmUZMVN07PonvVUQ=
github.com/casbin/casbin v1.7.0/go.mod h1:c67qKN6Oum3UF5Q1+BByfFxkwKvhwW57ITjqwtzR1KE=
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg=
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80=
github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb h1:w3RapLhkA5+km9Z8vUkC6VCaskduJXvXwJg5neKnfDU=
github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb/go.mod h1:TWI8EKQMs5u5jLKW/tsb9VwauIrMIxQG1r5fMsswK5U=
github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c h1:K4FIibkr4//ziZKOKmt4RL0YImuTjLLBtwElf+F2lSQ=
github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c/go.mod h1:srVSlQLB8iXBVXHgnqemxUXqN6FCvClgCMPCsjBDR7c=
github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a h1:Y5XsLCEhtEI8qbD9RP3Qlv5FXdTDHxZM9UPUnMRgBp8=
github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a/go.mod h1:BQwMFlJzDjFDG3DJUdU0KORxn88UlsOULuxLExMh3Hs=
github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 h1:Lgdd/Qp96Qj8jqLpq2cI1I1X7BJnu06efS+XkhRoLUQ=
github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76/go.mod h1:vYwsqCOLxGiisLwp9rITslkFNpZD5rz43tf41QFkTWY=
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 h1:aaQcKT9WumO6JEJcRyTqFVq4XUZiUcKR2/GI31TOcz8=
github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M=
github.com/elazarl/go-bindata-assetfs v1.0.0 h1:G/bYguwHIzWq9ZoyUQqrjTmJbbYn3j3CKKpKinvZLFk=
github.com/elazarl/go-bindata-assetfs v1.0.0/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4=
github.com/go-redis/redis v6.14.2+incompatible h1:UE9pLhzmWf+xHNmZsoccjXosPicuiNaInPgym8nzfg0=
github.com/go-redis/redis v6.14.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d h1:xy93KVe+KrIIwWDEAfQBdIfsiHJkepbYsDr+VY3g9/o=
github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85 h1:B7ZbAFz7NOmvpUE5RGtu3u0WIizy5GdvbNpEf4RPnWs=
github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85/go.mod h1:uZvAcrsnNaCxlh1HorK5dUQHGmEKPh2H/Rl1kehswPo=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0=
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o=
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 h1:xT+JlYxNGqyT+XcU8iUrN18JYed2TvG9yN5ULG2jATM=
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726/go.mod h1:3yhqj7WBBfRhbBlzyOC3gUxftwsU0u8gqevxwIHQpMw=
github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373 h1:p6IxqQMjab30l4lb9mmkIkkcE1yv6o0SKbPhW5pxqHI=
github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373/go.mod h1:mF1DpOSOUiJRMR+FDqaqu3EBqrybQtrDDszLUZ6oxPg=
github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d h1:NVwnfyR3rENtlz62bcrkXME3INVUa4lcdGt+opvxExs=
github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d/go.mod h1:AMEsy7v5z92TR1JKMkLLoaOQk++LVnOKL3ScbJ8GNGA=
github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec h1:q6XVwXmKvCRHRqesF3cSv6lNqqHi0QWOvgDlSohg8UA=
github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec/go.mod h1:QBvMkMya+gXctz3kmljlUCu/yB3GZ6oee+dUozsezQE=
github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c h1:3eGShk3EQf5gJCYW+WzA0TEJQd37HLOmlYF7N0YJwv0=
github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c/go.mod h1:Z4AUp2Km+PwemOoO/VB5AOx9XSsIItzFjoJlOSiYmn0=
github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b h1:0Ve0/CCjiAiyKddUMUn3RwIGlq2iTW4GuVzyoKBYO/8=
github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc=
golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 h1:et7+NAX3lLIk5qUCTA9QelBjGE/NkhzYw/mhnr0s7nI=
golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a h1:gOpx8G595UYyvj8UK4+OFyY4rx037g3fmfhe5SasG3U=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@ -28,12 +28,11 @@ func (c *graceConn) Close() (err error) {
}() }()
c.m.Lock() c.m.Lock()
defer c.m.Unlock()
if c.closed { if c.closed {
c.m.Unlock()
return return
} }
c.server.wg.Done() c.server.wg.Done()
c.closed = true c.closed = true
c.m.Unlock()
return c.Conn.Close() return c.Conn.Close()
} }

View File

@ -2,7 +2,9 @@ package grace
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -32,6 +34,11 @@ type Server struct {
// creating a new service goroutine for each. // creating a new service goroutine for each.
// The service goroutines read requests and then call srv.Handler to reply to them. // The service goroutines read requests and then call srv.Handler to reply to them.
func (srv *Server) Serve() (err error) { func (srv *Server) Serve() (err error) {
defer func() {
if r := recover(); r != nil {
log.Println("wait group counter is negative", r)
}
}()
srv.state = StateRunning srv.state = StateRunning
err = srv.Server.Serve(srv.GraceListener) err = srv.Server.Serve(srv.GraceListener)
log.Println(syscall.Getpid(), "Waiting for connections to finish...") log.Println(syscall.Getpid(), "Waiting for connections to finish...")
@ -65,7 +72,7 @@ func (srv *Server) ListenAndServe() (err error) {
log.Println(err) log.Println(err)
return err return err
} }
err = process.Kill() err = process.Signal(syscall.SIGTERM)
if err != nil { if err != nil {
return err return err
} }
@ -114,6 +121,62 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
srv.tlsInnerListener = newGraceListener(l, srv) srv.tlsInnerListener = newGraceListener(l, srv)
srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig) srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Signal(syscall.SIGTERM)
if err != nil {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
}
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming mutual TLS connections.
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
addr := srv.Addr
if addr == "" {
addr = ":https"
}
if srv.TLSConfig == nil {
srv.TLSConfig = &tls.Config{}
}
if srv.TLSConfig.NextProtos == nil {
srv.TLSConfig.NextProtos = []string{"http/1.1"}
}
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return
}
srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
pool := x509.NewCertPool()
data, err := ioutil.ReadFile(trustFile)
if err != nil {
log.Println(err)
return err
}
pool.AppendCertsFromPEM(data)
srv.TLSConfig.ClientCAs = pool
log.Println("Mutual HTTPS")
go srv.handleSignals()
l, err := srv.getListener(addr)
if err != nil {
log.Println(err)
return err
}
srv.tlsInnerListener = newGraceListener(l, srv)
srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
if srv.isChild { if srv.isChild {
process, err := os.FindProcess(os.Getppid()) process, err := os.FindProcess(os.Getppid())
if err != nil { if err != nil {

View File

@ -11,7 +11,7 @@ import (
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
) )
// // register MIME type with content type
func registerMime() error { func registerMime() error {
for k, v := range mimemaps { for k, v := range mimemaps {
mime.AddExtensionType(k, v) mime.AddExtensionType(k, v)

View File

@ -50,6 +50,7 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"gopkg.in/yaml.v2"
) )
var defaultSetting = BeegoHTTPSettings{ var defaultSetting = BeegoHTTPSettings{
@ -318,6 +319,34 @@ func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
return b return b
} }
// XMLBody adds request raw body encoding by XML.
func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil {
byts, err := xml.Marshal(obj)
if err != nil {
return b, err
}
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
b.req.ContentLength = int64(len(byts))
b.req.Header.Set("Content-Type", "application/xml")
}
return b, nil
}
// YAMLBody adds request raw body encoding by YAML.
func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil {
byts, err := yaml.Marshal(obj)
if err != nil {
return b, err
}
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
b.req.ContentLength = int64(len(byts))
b.req.Header.Set("Content-Type", "application/x+yaml")
}
return b, nil
}
// JSONBody adds request raw body encoding by JSON. // JSONBody adds request raw body encoding by JSON.
func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil { if b.req.Body == nil && obj != nil {
@ -417,12 +446,12 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) {
} }
b.buildURL(paramBody) b.buildURL(paramBody)
url, err := url.Parse(b.url) urlParsed, err := url.Parse(b.url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b.req.URL = url b.req.URL = urlParsed
trans := b.setting.Transport trans := b.setting.Transport
@ -432,7 +461,7 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) {
TLSClientConfig: b.setting.TLSClientConfig, TLSClientConfig: b.setting.TLSClientConfig,
Proxy: b.setting.Proxy, Proxy: b.setting.Proxy,
Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
MaxIdleConnsPerHost: -1, MaxIdleConnsPerHost: 100,
} }
} else { } else {
// if b.transport is *http.Transport then set the settings. // if b.transport is *http.Transport then set the settings.
@ -567,6 +596,16 @@ func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
return xml.Unmarshal(data, v) return xml.Unmarshal(data, v)
} }
// ToYAML returns the map that marshals from the body bytes as yaml in response .
// it calls Response inner.
func (b *BeegoHTTPRequest) ToYAML(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
}
return yaml.Unmarshal(data, v)
}
// Response executes request client gets response mannually. // Response executes request client gets response mannually.
func (b *BeegoHTTPRequest) Response() (*http.Response, error) { func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
return b.getResponse() return b.getResponse()

View File

@ -16,6 +16,8 @@ package httplib
import ( import (
"io/ioutil" "io/ioutil"
"net"
"net/http"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -161,7 +163,16 @@ func TestWithSetting(t *testing.T) {
var setting BeegoHTTPSettings var setting BeegoHTTPSettings
setting.EnableCookie = true setting.EnableCookie = true
setting.UserAgent = v setting.UserAgent = v
setting.Transport = nil setting.Transport = &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 50,
IdleConnTimeout: 90 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
setting.ReadWriteTimeout = 5 * time.Second setting.ReadWriteTimeout = 5 * time.Second
SetDefaultSetting(setting) SetDefaultSetting(setting)

View File

@ -16,48 +16,57 @@ As of now this logs support console, file,smtp and conn.
First you must import it First you must import it
import ( ```golang
"github.com/astaxie/beego/logs" import (
) "github.com/astaxie/beego/logs"
)
```
Then init a Log (example with console adapter) Then init a Log (example with console adapter)
log := NewLogger(10000) ```golang
log.SetLogger("console", "") log := logs.NewLogger(10000)
log.SetLogger("console", "")
```
> the first params stand for how many channel > the first params stand for how many channel
Use it like this: Use it like this:
log.Trace("trace")
log.Info("info")
log.Warn("warning")
log.Debug("debug")
log.Critical("critical")
```golang
log.Trace("trace")
log.Info("info")
log.Warn("warning")
log.Debug("debug")
log.Critical("critical")
```
## File adapter ## File adapter
Configure file adapter like this: Configure file adapter like this:
log := NewLogger(10000) ```golang
log.SetLogger("file", `{"filename":"test.log"}`) log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test.log"}`)
```
## Conn adapter ## Conn adapter
Configure like this: Configure like this:
log := NewLogger(1000) ```golang
log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) log := NewLogger(1000)
log.Info("info") log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`)
log.Info("info")
```
## Smtp adapter ## Smtp adapter
Configure like this: Configure like this:
log := NewLogger(10000) ```golang
log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`) log := NewLogger(10000)
log.Critical("sendmail critical") log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`)
time.Sleep(time.Second * 30) log.Critical("sendmail critical")
time.Sleep(time.Second * 30)
```

View File

@ -16,17 +16,19 @@ package logs
import ( import (
"bytes" "bytes"
"strings"
"encoding/json" "encoding/json"
"time"
"fmt" "fmt"
"time"
) )
const ( const (
apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s\n" apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s"
apacheFormat = "APACHE_FORMAT" apacheFormat = "APACHE_FORMAT"
jsonFormat = "JSON_FORMAT" jsonFormat = "JSON_FORMAT"
) )
// AccessLogRecord struct for holding access log data.
type AccessLogRecord struct { type AccessLogRecord struct {
RemoteAddr string `json:"remote_addr"` RemoteAddr string `json:"remote_addr"`
RequestTime time.Time `json:"request_time"` RequestTime time.Time `json:"request_time"`
@ -37,8 +39,8 @@ type AccessLogRecord struct {
Status int `json:"status"` Status int `json:"status"`
BodyBytesSent int64 `json:"body_bytes_sent"` BodyBytesSent int64 `json:"body_bytes_sent"`
ElapsedTime time.Duration `json:"elapsed_time"` ElapsedTime time.Duration `json:"elapsed_time"`
HttpReferrer string `json:"http_referrer"` HTTPReferrer string `json:"http_referrer"`
HttpUserAgent string `json:"http_user_agent"` HTTPUserAgent string `json:"http_user_agent"`
RemoteUser string `json:"remote_user"` RemoteUser string `json:"remote_user"`
} }
@ -52,23 +54,21 @@ func (r *AccessLogRecord) json() ([]byte, error) {
} }
func disableEscapeHTML(i interface{}) { func disableEscapeHTML(i interface{}) {
e, ok := i.(interface { if e, ok := i.(interface {
SetEscapeHTML(bool) SetEscapeHTML(bool)
}); }); ok {
if ok {
e.SetEscapeHTML(false) e.SetEscapeHTML(false)
} }
} }
// AccessLog - Format and print access log.
func AccessLog(r *AccessLogRecord, format string) { func AccessLog(r *AccessLogRecord, format string) {
var msg string var msg string
switch format { switch format {
case apacheFormat: case apacheFormat:
timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05") timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05")
msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent, msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent,
r.ElapsedTime.Seconds(), r.HttpReferrer, r.HttpUserAgent) r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent)
case jsonFormat: case jsonFormat:
fallthrough fallthrough
default: default:
@ -79,6 +79,5 @@ func AccessLog(r *AccessLogRecord, format string) {
msg = string(jsonData) msg = string(jsonData)
} }
} }
beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg))
beeLogger.Debug(msg)
} }

View File

@ -8,8 +8,8 @@ import (
"net/url" "net/url"
"time" "time"
"github.com/OwnLocal/goes"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/belogik/goes"
) )
// NewES return a LoggerInterface // NewES return a LoggerInterface
@ -21,7 +21,7 @@ func NewES() logs.Logger {
} }
type esLogger struct { type esLogger struct {
*goes.Connection *goes.Client
DSN string `json:"dsn"` DSN string `json:"dsn"`
Level int `json:"level"` Level int `json:"level"`
} }
@ -41,8 +41,8 @@ func (el *esLogger) Init(jsonconfig string) error {
} else if host, port, err := net.SplitHostPort(u.Host); err != nil { } else if host, port, err := net.SplitHostPort(u.Host); err != nil {
return err return err
} else { } else {
conn := goes.NewConnection(host, port) conn := goes.NewClient(host, port)
el.Connection = conn el.Client = conn
} }
return nil return nil
} }
@ -78,3 +78,4 @@ func (el *esLogger) Flush() {
func init() { func init() {
logs.Register(logs.AdapterEs, NewES) logs.Register(logs.AdapterEs, NewES)
} }

View File

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"path"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
@ -40,6 +41,9 @@ type fileLogWriter struct {
MaxLines int `json:"maxlines"` MaxLines int `json:"maxlines"`
maxLinesCurLines int maxLinesCurLines int
MaxFiles int `json:"maxfiles"`
MaxFilesCurFiles int
// Rotate at size // Rotate at size
MaxSize int `json:"maxsize"` MaxSize int `json:"maxsize"`
maxSizeCurSize int maxSizeCurSize int
@ -50,6 +54,12 @@ type fileLogWriter struct {
dailyOpenDate int dailyOpenDate int
dailyOpenTime time.Time dailyOpenTime time.Time
// Rotate hourly
Hourly bool `json:"hourly"`
MaxHours int64 `json:"maxhours"`
hourlyOpenDate int
hourlyOpenTime time.Time
Rotate bool `json:"rotate"` Rotate bool `json:"rotate"`
Level int `json:"level"` Level int `json:"level"`
@ -66,25 +76,30 @@ func newFileWriter() Logger {
w := &fileLogWriter{ w := &fileLogWriter{
Daily: true, Daily: true,
MaxDays: 7, MaxDays: 7,
Hourly: false,
MaxHours: 168,
Rotate: true, Rotate: true,
RotatePerm: "0440", RotatePerm: "0440",
Level: LevelTrace, Level: LevelTrace,
Perm: "0660", Perm: "0660",
MaxLines: 10000000,
MaxFiles: 999,
MaxSize: 1 << 28,
} }
return w return w
} }
// Init file logger with json config. // Init file logger with json config.
// jsonConfig like: // jsonConfig like:
// { // {
// "filename":"logs/beego.log", // "filename":"logs/beego.log",
// "maxLines":10000, // "maxLines":10000,
// "maxsize":1024, // "maxsize":1024,
// "daily":true, // "daily":true,
// "maxDays":15, // "maxDays":15,
// "rotate":true, // "rotate":true,
// "perm":"0600" // "perm":"0600"
// } // }
func (w *fileLogWriter) Init(jsonConfig string) error { func (w *fileLogWriter) Init(jsonConfig string) error {
err := json.Unmarshal([]byte(jsonConfig), w) err := json.Unmarshal([]byte(jsonConfig), w)
if err != nil { if err != nil {
@ -115,10 +130,16 @@ func (w *fileLogWriter) startLogger() error {
return w.initFd() return w.initFd()
} }
func (w *fileLogWriter) needRotate(size int, day int) bool { func (w *fileLogWriter) needRotateDaily(size int, day int) bool {
return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
(w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
(w.Daily && day != w.dailyOpenDate) (w.Daily && day != w.dailyOpenDate)
}
func (w *fileLogWriter) needRotateHourly(size int, hour int) bool {
return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
(w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
(w.Hourly && hour != w.hourlyOpenDate)
} }
@ -127,14 +148,23 @@ func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
if level > w.Level { if level > w.Level {
return nil return nil
} }
h, d := formatTimeHeader(when) hd, d, h := formatTimeHeader(when)
msg = string(h) + msg + "\n" msg = string(hd) + msg + "\n"
if w.Rotate { if w.Rotate {
w.RLock() w.RLock()
if w.needRotate(len(msg), d) { if w.needRotateHourly(len(msg), h) {
w.RUnlock() w.RUnlock()
w.Lock() w.Lock()
if w.needRotate(len(msg), d) { if w.needRotateHourly(len(msg), h) {
if err := w.doRotate(when); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
} else if w.needRotateDaily(len(msg), d) {
w.RUnlock()
w.Lock()
if w.needRotateDaily(len(msg), d) {
if err := w.doRotate(when); err != nil { if err := w.doRotate(when); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
} }
@ -161,6 +191,10 @@ func (w *fileLogWriter) createLogFile() (*os.File, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
filepath := path.Dir(w.Filename)
os.MkdirAll(filepath, os.FileMode(perm))
fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm))
if err == nil { if err == nil {
// Make sure file perm is user set perm cause of `os.OpenFile` will obey umask // Make sure file perm is user set perm cause of `os.OpenFile` will obey umask
@ -178,11 +212,15 @@ func (w *fileLogWriter) initFd() error {
w.maxSizeCurSize = int(fInfo.Size()) w.maxSizeCurSize = int(fInfo.Size())
w.dailyOpenTime = time.Now() w.dailyOpenTime = time.Now()
w.dailyOpenDate = w.dailyOpenTime.Day() w.dailyOpenDate = w.dailyOpenTime.Day()
w.hourlyOpenTime = time.Now()
w.hourlyOpenDate = w.hourlyOpenTime.Hour()
w.maxLinesCurLines = 0 w.maxLinesCurLines = 0
if w.Daily { if w.Hourly {
go w.hourlyRotate(w.hourlyOpenTime)
} else if w.Daily {
go w.dailyRotate(w.dailyOpenTime) go w.dailyRotate(w.dailyOpenTime)
} }
if fInfo.Size() > 0 { if fInfo.Size() > 0 && w.MaxLines > 0 {
count, err := w.lines() count, err := w.lines()
if err != nil { if err != nil {
return err return err
@ -198,7 +236,22 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) {
tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100))
<-tm.C <-tm.C
w.Lock() w.Lock()
if w.needRotate(0, time.Now().Day()) { if w.needRotateDaily(0, time.Now().Day()) {
if err := w.doRotate(time.Now()); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
}
w.Unlock()
}
func (w *fileLogWriter) hourlyRotate(openTime time.Time) {
y, m, d := openTime.Add(1 * time.Hour).Date()
h, _, _ := openTime.Add(1 * time.Hour).Clock()
nextHour := time.Date(y, m, d, h, 0, 0, 0, openTime.Location())
tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100))
<-tm.C
w.Lock()
if w.needRotateHourly(0, time.Now().Hour()) {
if err := w.doRotate(time.Now()); err != nil { if err := w.doRotate(time.Now()); err != nil {
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
} }
@ -238,8 +291,10 @@ func (w *fileLogWriter) lines() (int, error) {
func (w *fileLogWriter) doRotate(logTime time.Time) error { func (w *fileLogWriter) doRotate(logTime time.Time) error {
// file exists // file exists
// Find the next available number // Find the next available number
num := 1 num := w.MaxFilesCurFiles + 1
fName := "" fName := ""
format := ""
var openTime time.Time
rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64) rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64)
if err != nil { if err != nil {
return err return err
@ -251,19 +306,26 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
goto RESTART_LOGGER goto RESTART_LOGGER
} }
if w.Hourly {
format = "2006010215"
openTime = w.hourlyOpenTime
} else if w.Daily {
format = "2006-01-02"
openTime = w.dailyOpenTime
}
// only when one of them be setted, then the file would be splited
if w.MaxLines > 0 || w.MaxSize > 0 { if w.MaxLines > 0 || w.MaxSize > 0 {
for ; err == nil && num <= 999; num++ { for ; err == nil && num <= w.MaxFiles; num++ {
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format("2006-01-02"), num, w.suffix) fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format(format), num, w.suffix)
_, err = os.Lstat(fName) _, err = os.Lstat(fName)
} }
} else { } else {
fName = fmt.Sprintf("%s.%s%s", w.fileNameOnly, w.dailyOpenTime.Format("2006-01-02"), w.suffix) fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix)
_, err = os.Lstat(fName) _, err = os.Lstat(fName)
for ; err == nil && num <= 999; num++ { w.MaxFilesCurFiles = num
fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", w.dailyOpenTime.Format("2006-01-02"), num, w.suffix)
_, err = os.Lstat(fName)
}
} }
// return error if the last file checked still existed // return error if the last file checked still existed
if err == nil { if err == nil {
return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename) return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename)
@ -307,13 +369,21 @@ func (w *fileLogWriter) deleteOldLog() {
if info == nil { if info == nil {
return return
} }
if w.Hourly {
if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) { if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
strings.HasSuffix(filepath.Base(path), w.suffix) { strings.HasSuffix(filepath.Base(path), w.suffix) {
os.Remove(path) os.Remove(path)
} }
} }
} else if w.Daily {
if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) {
if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
strings.HasSuffix(filepath.Base(path), w.suffix) {
os.Remove(path)
}
}
}
return return
}) })
} }

View File

@ -112,7 +112,7 @@ func TestFile2(t *testing.T) {
os.Remove("test2.log") os.Remove("test2.log")
} }
func TestFileRotate_01(t *testing.T) { func TestFileDailyRotate_01(t *testing.T) {
log := NewLogger(10000) log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
log.Debug("debug") log.Debug("debug")
@ -133,28 +133,28 @@ func TestFileRotate_01(t *testing.T) {
os.Remove("test3.log") os.Remove("test3.log")
} }
func TestFileRotate_02(t *testing.T) { func TestFileDailyRotate_02(t *testing.T) {
fn1 := "rotate_day.log" fn1 := "rotate_day.log"
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
testFileRotate(t, fn1, fn2) testFileRotate(t, fn1, fn2, true, false)
} }
func TestFileRotate_03(t *testing.T) { func TestFileDailyRotate_03(t *testing.T) {
fn1 := "rotate_day.log" fn1 := "rotate_day.log"
fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
os.Create(fn) os.Create(fn)
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log" fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
testFileRotate(t, fn1, fn2) testFileRotate(t, fn1, fn2, true, false)
os.Remove(fn) os.Remove(fn)
} }
func TestFileRotate_04(t *testing.T) { func TestFileDailyRotate_04(t *testing.T) {
fn1 := "rotate_day.log" fn1 := "rotate_day.log"
fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
testFileDailyRotate(t, fn1, fn2) testFileDailyRotate(t, fn1, fn2)
} }
func TestFileRotate_05(t *testing.T) { func TestFileDailyRotate_05(t *testing.T) {
fn1 := "rotate_day.log" fn1 := "rotate_day.log"
fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log" fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
os.Create(fn) os.Create(fn)
@ -162,7 +162,7 @@ func TestFileRotate_05(t *testing.T) {
testFileDailyRotate(t, fn1, fn2) testFileDailyRotate(t, fn1, fn2)
os.Remove(fn) os.Remove(fn)
} }
func TestFileRotate_06(t *testing.T) { //test file mode func TestFileDailyRotate_06(t *testing.T) { //test file mode
log := NewLogger(10000) log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
log.Debug("debug") log.Debug("debug")
@ -183,23 +183,110 @@ func TestFileRotate_06(t *testing.T) { //test file mode
os.Remove(rotateName) os.Remove(rotateName)
os.Remove("test3.log") os.Remove("test3.log")
} }
func testFileRotate(t *testing.T, fn1, fn2 string) {
func TestFileHourlyRotate_01(t *testing.T) {
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`)
log.Debug("debug")
log.Info("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log"
b, err := exists(rotateName)
if !b || err != nil {
os.Remove("test3.log")
t.Fatal("rotate not generated")
}
os.Remove(rotateName)
os.Remove("test3.log")
}
func TestFileHourlyRotate_02(t *testing.T) {
fn1 := "rotate_hour.log"
fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
testFileRotate(t, fn1, fn2, false, true)
}
func TestFileHourlyRotate_03(t *testing.T) {
fn1 := "rotate_hour.log"
fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log"
os.Create(fn)
fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
testFileRotate(t, fn1, fn2, false, true)
os.Remove(fn)
}
func TestFileHourlyRotate_04(t *testing.T) {
fn1 := "rotate_hour.log"
fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
testFileHourlyRotate(t, fn1, fn2)
}
func TestFileHourlyRotate_05(t *testing.T) {
fn1 := "rotate_hour.log"
fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log"
os.Create(fn)
fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
testFileHourlyRotate(t, fn1, fn2)
os.Remove(fn)
}
func TestFileHourlyRotate_06(t *testing.T) { //test file mode
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`)
log.Debug("debug")
log.Info("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log"
s, _ := os.Lstat(rotateName)
if s.Mode() != 0440 {
os.Remove(rotateName)
os.Remove("test3.log")
t.Fatal("rotate file mode error")
}
os.Remove(rotateName)
os.Remove("test3.log")
}
func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) {
fw := &fileLogWriter{ fw := &fileLogWriter{
Daily: true, Daily: daily,
MaxDays: 7, MaxDays: 7,
Hourly: hourly,
MaxHours: 168,
Rotate: true, Rotate: true,
Level: LevelTrace, Level: LevelTrace,
Perm: "0660", Perm: "0660",
RotatePerm: "0440", RotatePerm: "0440",
} }
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) if daily {
fw.dailyOpenDate = fw.dailyOpenTime.Day() fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
fw.dailyOpenDate = fw.dailyOpenTime.Day()
}
if hourly {
fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1))
fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour)
fw.hourlyOpenDate = fw.hourlyOpenTime.Day()
}
fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug)
for _, file := range []string{fn1, fn2} { for _, file := range []string{fn1, fn2} {
_, err := os.Stat(file) _, err := os.Stat(file)
if err != nil { if err != nil {
t.Log(err)
t.FailNow() t.FailNow()
} }
os.Remove(file) os.Remove(file)
@ -239,6 +326,37 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) {
fw.Destroy() fw.Destroy()
} }
func testFileHourlyRotate(t *testing.T, fn1, fn2 string) {
fw := &fileLogWriter{
Hourly: true,
MaxHours: 168,
Rotate: true,
Level: LevelTrace,
Perm: "0660",
RotatePerm: "0440",
}
fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1))
fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour)
fw.hourlyOpenDate = fw.hourlyOpenTime.Hour()
hour, _ := time.ParseInLocation("2006010215", time.Now().Format("2006010215"), fw.hourlyOpenTime.Location())
hour = hour.Add(-1 * time.Second)
fw.hourlyRotate(hour)
for _, file := range []string{fn1, fn2} {
_, err := os.Stat(file)
if err != nil {
t.FailNow()
}
content, err := ioutil.ReadFile(file)
if err != nil {
t.FailNow()
}
if len(content) > 0 {
t.FailNow()
}
os.Remove(file)
}
fw.Destroy()
}
func exists(path string) (bool, error) { func exists(path string) (bool, error) {
_, err := os.Stat(path) _, err := os.Stat(path)
if err == nil { if err == nil {

View File

@ -309,6 +309,11 @@ func (bl *BeeLogger) SetLevel(l int) {
bl.level = l bl.level = l
} }
// GetLevel Get Current log message level.
func (bl *BeeLogger) GetLevel() int {
return bl.level
}
// SetLogFuncCallDepth set log funcCallDepth // SetLogFuncCallDepth set log funcCallDepth
func (bl *BeeLogger) SetLogFuncCallDepth(d int) { func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
bl.loggerFuncCallDepth = d bl.loggerFuncCallDepth = d

View File

@ -33,7 +33,7 @@ func newLogWriter(wr io.Writer) *logWriter {
func (lg *logWriter) println(when time.Time, msg string) { func (lg *logWriter) println(when time.Time, msg string) {
lg.Lock() lg.Lock()
h, _ := formatTimeHeader(when) h, _, _:= formatTimeHeader(when)
lg.writer.Write(append(append(h, msg...), '\n')) lg.writer.Write(append(append(h, msg...), '\n'))
lg.Unlock() lg.Unlock()
} }
@ -90,10 +90,10 @@ const (
ns1 = `0123456789` ns1 = `0123456789`
) )
func formatTimeHeader(when time.Time) ([]byte, int) { func formatTimeHeader(when time.Time) ([]byte, int, int) {
y, mo, d := when.Date() y, mo, d := when.Date()
h, mi, s := when.Clock() h, mi, s := when.Clock()
ns := when.Nanosecond()/1000000 ns := when.Nanosecond() / 1000000
//len("2006/01/02 15:04:05.123 ")==24 //len("2006/01/02 15:04:05.123 ")==24
var buf [24]byte var buf [24]byte
@ -123,7 +123,7 @@ func formatTimeHeader(when time.Time) ([]byte, int) {
buf[23] = ' ' buf[23] = ' '
return buf[0:], d return buf[0:], d, h
} }
var ( var (

View File

@ -30,8 +30,8 @@ func TestFormatHeader_0(t *testing.T) {
if tm.Year() >= 2100 { if tm.Year() >= 2100 {
break break
} }
h, _ := formatTimeHeader(tm) h, _, _ := formatTimeHeader(tm)
if tm.Format("2006/01/02 15:04:05.999 ") != string(h) { if tm.Format("2006/01/02 15:04:05.000 ") != string(h) {
t.Log(tm) t.Log(tm)
t.FailNow() t.FailNow()
} }
@ -48,8 +48,8 @@ func TestFormatHeader_1(t *testing.T) {
if tm.Year() >= year+1 { if tm.Year() >= year+1 {
break break
} }
h, _ := formatTimeHeader(tm) h, _, _ := formatTimeHeader(tm)
if tm.Format("2006/01/02 15:04:05.999 ") != string(h) { if tm.Format("2006/01/02 15:04:05.000 ") != string(h) {
t.Log(tm) t.Log(tm)
t.FailNow() t.FailNow()
} }

View File

@ -67,7 +67,10 @@ func (f *multiFileLogWriter) Init(config string) error {
jsonMap["level"] = i jsonMap["level"] = i
bs, _ := json.Marshal(jsonMap) bs, _ := json.Marshal(jsonMap)
writer = newFileWriter().(*fileLogWriter) writer = newFileWriter().(*fileLogWriter)
writer.Init(string(bs)) err := writer.Init(string(bs))
if err != nil {
return err
}
f.writers[i] = writer f.writers[i] = writer
} }
} }

View File

@ -322,7 +322,7 @@ func (m *Migration) GetSQL() (sql string) {
sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
} }
if len(m.Columns) > index { if len(m.Columns) > index+1 {
sql += "," sql += ","
} }
} }
@ -355,7 +355,7 @@ func (m *Migration) GetSQL() (sql string) {
} else { } else {
sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name) sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
} }
if len(m.Columns) > index { if len(m.Columns) > index+1 {
sql += "," sql += ","
} }
} }
@ -366,14 +366,14 @@ func (m *Migration) GetSQL() (sql string) {
for index, unique := range m.Uniques { for index, unique := range m.Uniques {
sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition) sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition)
if len(m.Uniques) > index { if len(m.Uniques) > index+1 {
sql += "," sql += ","
} }
} }
for index, column := range m.Renames { for index, column := range m.Renames {
sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault) sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault)
if len(m.Renames) > index { if len(m.Renames) > index+1 {
sql += "," sql += ","
} }
} }

View File

@ -172,7 +172,7 @@ func Register(name string, m Migrationer) error {
return nil return nil
} }
// Upgrade upgrate the migration from lasttime // Upgrade upgrade the migration from lasttime
func Upgrade(lasttime int64) error { func Upgrade(lasttime int64) error {
sm := sortMap(migrationMap) sm := sortMap(migrationMap)
i := 0 i := 0

View File

@ -207,11 +207,11 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
for _, ni := range ns { for _, ni := range ns {
for k, v := range ni.handlers.routers { for k, v := range ni.handlers.routers {
if t, ok := n.handlers.routers[k]; ok { if _, ok := n.handlers.routers[k]; ok {
addPrefix(v, ni.prefix) addPrefix(v, ni.prefix)
n.handlers.routers[k].AddTree(ni.prefix, v) n.handlers.routers[k].AddTree(ni.prefix, v)
} else { } else {
t = NewTree() t := NewTree()
t.AddTree(ni.prefix, v) t.AddTree(ni.prefix, v)
addPrefix(t, ni.prefix) addPrefix(t, ni.prefix)
n.handlers.routers[k] = t n.handlers.routers[k] = t
@ -236,11 +236,11 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
func AddNamespace(nl ...*Namespace) { func AddNamespace(nl ...*Namespace) {
for _, n := range nl { for _, n := range nl {
for k, v := range n.handlers.routers { for k, v := range n.handlers.routers {
if t, ok := BeeApp.Handlers.routers[k]; ok { if _, ok := BeeApp.Handlers.routers[k]; ok {
addPrefix(v, n.prefix) addPrefix(v, n.prefix)
BeeApp.Handlers.routers[k].AddTree(n.prefix, v) BeeApp.Handlers.routers[k].AddTree(n.prefix, v)
} else { } else {
t = NewTree() t := NewTree()
t.AddTree(n.prefix, v) t.AddTree(n.prefix, v)
addPrefix(t, n.prefix) addPrefix(t, n.prefix)
BeeApp.Handlers.routers[k] = t BeeApp.Handlers.routers[k] = t

View File

@ -51,12 +51,14 @@ checkColumn:
switch fieldType { switch fieldType {
case TypeBooleanField: case TypeBooleanField:
col = T["bool"] col = T["bool"]
case TypeCharField: case TypeVarCharField:
if al.Driver == DRPostgres && fi.toText { if al.Driver == DRPostgres && fi.toText {
col = T["string-text"] col = T["string-text"]
} else { } else {
col = fmt.Sprintf(T["string"], fieldSize) col = fmt.Sprintf(T["string"], fieldSize)
} }
case TypeCharField:
col = fmt.Sprintf(T["string-char"], fieldSize)
case TypeTextField: case TypeTextField:
col = T["string-text"] col = T["string-text"]
case TypeTimeField: case TypeTimeField:
@ -96,13 +98,13 @@ checkColumn:
} }
case TypeJSONField: case TypeJSONField:
if al.Driver != DRPostgres { if al.Driver != DRPostgres {
fieldType = TypeCharField fieldType = TypeVarCharField
goto checkColumn goto checkColumn
} }
col = T["json"] col = T["json"]
case TypeJsonbField: case TypeJsonbField:
if al.Driver != DRPostgres { if al.Driver != DRPostgres {
fieldType = TypeCharField fieldType = TypeVarCharField
goto checkColumn goto checkColumn
} }
col = T["jsonb"] col = T["jsonb"]
@ -195,6 +197,10 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
if strings.Contains(column, "%COL%") { if strings.Contains(column, "%COL%") {
column = strings.Replace(column, "%COL%", fi.column, -1) column = strings.Replace(column, "%COL%", fi.column, -1)
} }
if fi.description != "" {
column += " " + fmt.Sprintf("COMMENT '%s'",fi.description)
}
columns = append(columns, column) columns = append(columns, column)
} }

View File

@ -142,7 +142,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else { } else {
value = field.Bool() value = field.Bool()
} }
case TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField: case TypeVarCharField, TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField:
if ns, ok := field.Interface().(sql.NullString); ok { if ns, ok := field.Interface().(sql.NullString); ok {
value = nil value = nil
if ns.Valid { if ns.Valid {
@ -536,6 +536,8 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
updates := make([]string, len(names)) updates := make([]string, len(names))
var conflitValue interface{} var conflitValue interface{}
for i, v := range names { for i, v := range names {
// identifier in database may not be case-sensitive, so quote it
v = fmt.Sprintf("%s%s%s", Q, v, Q)
marks[i] = "?" marks[i] = "?"
valueStr := argsMap[strings.ToLower(v)] valueStr := argsMap[strings.ToLower(v)]
if v == args0 { if v == args0 {
@ -760,7 +762,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} }
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, values...) var err error
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, values...)
} else {
res, err = q.Exec(query, values...)
}
if err == nil { if err == nil {
return res.RowsAffected() return res.RowsAffected()
} }
@ -849,11 +857,16 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
for i := range marks { for i := range marks {
marks[i] = "?" marks[i] = "?"
} }
sql := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, args...) var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, args...)
} else {
res, err = q.Exec(query, args...)
}
if err == nil { if err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
if err != nil { if err != nil {
@ -926,7 +939,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
maps[fi.column] = true maps[fi.column] = true
} }
} else { } else {
panic(fmt.Errorf("wrong field/column name `%s`", col)) return 0, fmt.Errorf("wrong field/column name `%s`", col)
} }
} }
if hasRel { if hasRel {
@ -969,14 +982,25 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
} }
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit)
if qs.forupdate {
query += " FOR UPDATE"
}
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows var rs *sql.Rows
r, err := q.Query(query, args...) var err error
if err != nil { if qs != nil && qs.forContext {
return 0, err rs, err = q.QueryContext(qs.ctx, query, args...)
if err != nil {
return 0, err
}
} else {
rs, err = q.Query(query, args...)
if err != nil {
return 0, err
}
} }
rs = r
refs := make([]interface{}, colsNum) refs := make([]interface{}, colsNum)
for i := range refs { for i := range refs {
@ -1105,8 +1129,12 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
row := q.QueryRow(query, args...) var row *sql.Row
if qs != nil && qs.forContext {
row = q.QueryRowContext(qs.ctx, query, args...)
} else {
row = q.QueryRow(query, args...)
}
err = row.Scan(&cnt) err = row.Scan(&cnt)
return return
} }
@ -1240,7 +1268,7 @@ setValue:
} }
value = b value = b
} }
case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if str == nil { if str == nil {
value = ToStr(val) value = ToStr(val)
} else { } else {
@ -1386,7 +1414,7 @@ setValue:
field.SetBool(value.(bool)) field.SetBool(value.(bool))
} }
} }
case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if isNative { if isNative {
if ns, ok := field.Interface().(sql.NullString); ok { if ns, ok := field.Interface().(sql.NullString); ok {
if value == nil { if value == nil {

View File

@ -119,7 +119,7 @@ type alias struct {
func detectTZ(al *alias) { func detectTZ(al *alias) {
// orm timezone system match database // orm timezone system match database
// default use Local // default use Local
al.TZ = time.Local al.TZ = DefaultTimeLoc
if al.DriverName == "sphinx" { if al.DriverName == "sphinx" {
return return
@ -136,7 +136,9 @@ func detectTZ(al *alias) {
} }
t, err := time.Parse("-07:00:00", tz) t, err := time.Parse("-07:00:00", tz)
if err == nil { if err == nil {
al.TZ = t.Location() if t.Location().String() != "" {
al.TZ = t.Location()
}
} else { } else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
} }

View File

@ -46,6 +46,7 @@ var mysqlTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
"bool": "bool", "bool": "bool",
"string": "varchar(%d)", "string": "varchar(%d)",
"string-char": "char(%d)",
"string-text": "longtext", "string-text": "longtext",
"time.Time-date": "date", "time.Time-date": "date",
"time.Time": "datetime", "time.Time": "datetime",

View File

@ -34,6 +34,7 @@ var oracleTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
"bool": "bool", "bool": "bool",
"string": "VARCHAR2(%d)", "string": "VARCHAR2(%d)",
"string-char": "CHAR(%d)",
"string-text": "VARCHAR2(%d)", "string-text": "VARCHAR2(%d)",
"time.Time-date": "DATE", "time.Time-date": "DATE",
"time.Time": "TIMESTAMP", "time.Time": "TIMESTAMP",

View File

@ -43,6 +43,7 @@ var postgresTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
"bool": "bool", "bool": "bool",
"string": "varchar(%d)", "string": "varchar(%d)",
"string-char": "char(%d)",
"string-text": "text", "string-text": "text",
"time.Time-date": "date", "time.Time-date": "date",
"time.Time": "timestamp with time zone", "time.Time": "timestamp with time zone",

View File

@ -43,6 +43,7 @@ var sqliteTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
"bool": "bool", "bool": "bool",
"string": "varchar(%d)", "string": "varchar(%d)",
"string-char": "character(%d)",
"string-text": "text", "string-text": "text",
"time.Time-date": "date", "time.Time-date": "date",
"time.Time": "datetime", "time.Time": "datetime",

View File

@ -372,7 +372,13 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe
operator = "exact" operator = "exact"
} }
operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) var operSQL string
var args []interface{}
if p.isRaw {
operSQL = p.sql
} else {
operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
}
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)

View File

@ -52,7 +52,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
return m return m
} }
// get orderd model info // get ordered model info
func (mc *_modelCache) allOrdered() []*modelInfo { func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders)) m := make([]*modelInfo, 0, len(mc.orders))
for _, table := range mc.orders { for _, table := range mc.orders {

View File

@ -89,7 +89,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
modelCache.set(table, mi) modelCache.set(table, mi)
} }
// boostrap models // bootstrap models
func bootStrap() { func bootStrap() {
if modelCache.done { if modelCache.done {
return return
@ -332,7 +332,7 @@ func RegisterModelWithSuffix(suffix string, models ...interface{}) {
} }
} }
// BootStrap bootrap models. // BootStrap bootstrap models.
// make all model parsed and can not add more models // make all model parsed and can not add more models
func BootStrap() { func BootStrap() {
if modelCache.done { if modelCache.done {

View File

@ -23,6 +23,7 @@ import (
// Define the Type enum // Define the Type enum
const ( const (
TypeBooleanField = 1 << iota TypeBooleanField = 1 << iota
TypeVarCharField
TypeCharField TypeCharField
TypeTextField TypeTextField
TypeTimeField TypeTimeField
@ -49,9 +50,9 @@ const (
// Define some logic enum // Define some logic enum
const ( const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 5 << 6 IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7
IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 9 << 10 IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11
IsRelField = ^-RelReverseMany >> 17 << 18 IsRelField = ^-RelReverseMany >> 18 << 19
IsFieldType = ^-RelReverseMany<<1 + 1 IsFieldType = ^-RelReverseMany<<1 + 1
) )
@ -85,7 +86,7 @@ func (e *BooleanField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := StrTo(d).Bool() v, err := StrTo(d).Bool()
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err return err
@ -126,7 +127,7 @@ func (e *CharField) String() string {
// FieldType return the enum type // FieldType return the enum type
func (e *CharField) FieldType() int { func (e *CharField) FieldType() int {
return TypeCharField return TypeVarCharField
} }
// SetRaw set the interface to string // SetRaw set the interface to string
@ -190,7 +191,7 @@ func (e *TimeField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := timeParse(d, formatTime) v, err := timeParse(d, formatTime)
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err return err
@ -232,7 +233,7 @@ func (e *DateField) Set(d time.Time) {
*e = DateField(d) *e = DateField(d)
} }
// String convert datatime to string // String convert datetime to string
func (e *DateField) String() string { func (e *DateField) String() string {
return e.Value().String() return e.Value().String()
} }
@ -249,7 +250,7 @@ func (e *DateField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := timeParse(d, formatDate) v, err := timeParse(d, formatDate)
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err return err
@ -272,12 +273,12 @@ var _ Fielder = new(DateField)
// Takes the same extra arguments as DateField. // Takes the same extra arguments as DateField.
type DateTimeField time.Time type DateTimeField time.Time
// Value return the datatime value // Value return the datetime value
func (e DateTimeField) Value() time.Time { func (e DateTimeField) Value() time.Time {
return time.Time(e) return time.Time(e)
} }
// Set set the time.Time to datatime // Set set the time.Time to datetime
func (e *DateTimeField) Set(d time.Time) { func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d) *e = DateTimeField(d)
} }
@ -299,7 +300,7 @@ func (e *DateTimeField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := timeParse(d, formatDateTime) v, err := timeParse(d, formatDateTime)
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err return err
@ -309,12 +310,12 @@ func (e *DateTimeField) SetRaw(value interface{}) error {
return nil return nil
} }
// RawValue return the datatime value // RawValue return the datetime value
func (e *DateTimeField) RawValue() interface{} { func (e *DateTimeField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify datatime implement fielder // verify datetime implement fielder
var _ Fielder = new(DateTimeField) var _ Fielder = new(DateTimeField)
// FloatField A floating-point number represented in go by a float32 value. // FloatField A floating-point number represented in go by a float32 value.
@ -349,9 +350,10 @@ func (e *FloatField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := StrTo(d).Float64() v, err := StrTo(d).Float64()
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err
default: default:
return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value) return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value)
} }
@ -396,9 +398,10 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := StrTo(d).Int16() v, err := StrTo(d).Int16()
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err
default: default:
return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value) return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value)
} }
@ -443,9 +446,10 @@ func (e *IntegerField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := StrTo(d).Int32() v, err := StrTo(d).Int32()
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err
default: default:
return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value) return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value)
} }
@ -490,9 +494,10 @@ func (e *BigIntegerField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := StrTo(d).Int64() v, err := StrTo(d).Int64()
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err
default: default:
return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value) return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value)
} }
@ -537,9 +542,10 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := StrTo(d).Uint16() v, err := StrTo(d).Uint16()
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err
default: default:
return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value) return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value)
} }
@ -584,9 +590,10 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := StrTo(d).Uint32() v, err := StrTo(d).Uint32()
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err
default: default:
return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value) return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value)
} }
@ -631,9 +638,10 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
e.Set(d) e.Set(d)
case string: case string:
v, err := StrTo(d).Uint64() v, err := StrTo(d).Uint64()
if err != nil { if err == nil {
e.Set(v) e.Set(v)
} }
return err
default: default:
return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value) return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value)
} }

View File

@ -136,6 +136,7 @@ type fieldInfo struct {
decimals int decimals int
isFielder bool // implement Fielder interface isFielder bool // implement Fielder interface
onDelete string onDelete string
description string
} }
// new field info // new field info
@ -244,8 +245,10 @@ checkType:
if err != nil { if err != nil {
goto end goto end
} }
if fieldType == TypeCharField { if fieldType == TypeVarCharField {
switch tags["type"] { switch tags["type"] {
case "char":
fieldType = TypeCharField
case "text": case "text":
fieldType = TypeTextField fieldType = TypeTextField
case "json": case "json":
@ -298,6 +301,7 @@ checkType:
fi.sf = sf fi.sf = sf
fi.fullName = mi.fullName + mName + "." + sf.Name fi.fullName = mi.fullName + mName + "." + sf.Name
fi.description = tags["description"]
fi.null = attrs["null"] fi.null = attrs["null"]
fi.index = attrs["index"] fi.index = attrs["index"]
fi.auto = attrs["auto"] fi.auto = attrs["auto"]
@ -357,7 +361,7 @@ checkType:
switch fieldType { switch fieldType {
case TypeBooleanField: case TypeBooleanField:
case TypeCharField, TypeJSONField, TypeJsonbField: case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField:
if size != "" { if size != "" {
v, e := StrTo(size).Int32() v, e := StrTo(size).Int32()
if e != nil { if e != nil {

View File

@ -75,7 +75,8 @@ func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int)
break break
} }
//record current field index //record current field index
fi.fieldIndex = append(index, i) fi.fieldIndex = append(fi.fieldIndex, index...)
fi.fieldIndex = append(fi.fieldIndex, i)
fi.mi = mi fi.mi = mi
fi.inModel = true fi.inModel = true
if !mi.fields.Add(fi) { if !mi.fields.Add(fi) {

View File

@ -49,7 +49,7 @@ func (e *SliceStringField) String() string {
} }
func (e *SliceStringField) FieldType() int { func (e *SliceStringField) FieldType() int {
return TypeCharField return TypeVarCharField
} }
func (e *SliceStringField) SetRaw(value interface{}) error { func (e *SliceStringField) SetRaw(value interface{}) error {
@ -433,53 +433,57 @@ var (
dDbBaser dbBaser dDbBaser dbBaser
) )
var (
helpinfo = `need driver and source!
Default DB Drivers.
driver: url
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq
tidb: https://github.com/pingcap/tidb
usage:
go get -u github.com/astaxie/beego/orm
go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq
go get -u github.com/pingcap/tidb
#### MySQL
mysql -u root -e 'create database orm_test;'
export ORM_DRIVER=mysql
export ORM_SOURCE="root:@/orm_test?charset=utf8"
go test -v github.com/astaxie/beego/orm
#### Sqlite3
export ORM_DRIVER=sqlite3
export ORM_SOURCE='file:memory_test?mode=memory'
go test -v github.com/astaxie/beego/orm
#### PostgreSQL
psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/astaxie/beego/orm
#### TiDB
export ORM_DRIVER=tidb
export ORM_SOURCE='memory://test/test'
go test -v github.com/astaxie/beego/orm
`
)
func init() { func init() {
Debug, _ = StrTo(DBARGS.Debug).Bool() Debug, _ = StrTo(DBARGS.Debug).Bool()
if DBARGS.Driver == "" || DBARGS.Source == "" { if DBARGS.Driver == "" || DBARGS.Source == "" {
fmt.Println(`need driver and source! fmt.Println(helpinfo)
Default DB Drivers.
driver: url
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/lib/pq
tidb: https://github.com/pingcap/tidb
usage:
go get -u github.com/astaxie/beego/orm
go get -u github.com/go-sql-driver/mysql
go get -u github.com/mattn/go-sqlite3
go get -u github.com/lib/pq
go get -u github.com/pingcap/tidb
#### MySQL
mysql -u root -e 'create database orm_test;'
export ORM_DRIVER=mysql
export ORM_SOURCE="root:@/orm_test?charset=utf8"
go test -v github.com/astaxie/beego/orm
#### Sqlite3
export ORM_DRIVER=sqlite3
export ORM_SOURCE='file:memory_test?mode=memory'
go test -v github.com/astaxie/beego/orm
#### PostgreSQL
psql -c 'create database orm_test;' -U postgres
export ORM_DRIVER=postgres
export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
go test -v github.com/astaxie/beego/orm
#### TiDB
export ORM_DRIVER=tidb
export ORM_SOURCE='memory://test/test'
go test -v github.com/astaxie/beego/orm
`)
os.Exit(2) os.Exit(2)
} }

View File

@ -44,6 +44,7 @@ var supportTag = map[string]int{
"decimals": 2, "decimals": 2,
"on_delete": 2, "on_delete": 2,
"type": 2, "type": 2,
"description": 2,
} }
// get reflect.Type name with package path. // get reflect.Type name with package path.
@ -109,7 +110,7 @@ func getTableUnique(val reflect.Value) [][]string {
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
column := col column := col
if col == "" { if col == "" {
column = snakeString(sf.Name) column = nameStrategyMap[nameStrategy](sf.Name)
} }
switch ft { switch ft {
case RelForeignKey, RelOneToOne: case RelForeignKey, RelOneToOne:
@ -149,7 +150,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
case reflect.TypeOf(new(bool)): case reflect.TypeOf(new(bool)):
ft = TypeBooleanField ft = TypeBooleanField
case reflect.TypeOf(new(string)): case reflect.TypeOf(new(string)):
ft = TypeCharField ft = TypeVarCharField
case reflect.TypeOf(new(time.Time)): case reflect.TypeOf(new(time.Time)):
ft = TypeDateTimeField ft = TypeDateTimeField
default: default:
@ -176,7 +177,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
case reflect.Bool: case reflect.Bool:
ft = TypeBooleanField ft = TypeBooleanField
case reflect.String: case reflect.String:
ft = TypeCharField ft = TypeVarCharField
default: default:
if elm.Interface() == nil { if elm.Interface() == nil {
panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val))
@ -189,7 +190,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
case sql.NullBool: case sql.NullBool:
ft = TypeBooleanField ft = TypeBooleanField
case sql.NullString: case sql.NullString:
ft = TypeCharField ft = TypeVarCharField
case time.Time: case time.Time:
ft = TypeDateTimeField ft = TypeDateTimeField
} }

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// +build go1.8
// Package orm provide ORM for MySQL/PostgreSQL/sqlite // Package orm provide ORM for MySQL/PostgreSQL/sqlite
// Simple Usage // Simple Usage
// //
@ -52,6 +54,7 @@
package orm package orm
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -422,7 +425,7 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
var name string var name string
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
name = snakeString(table) name = nameStrategyMap[defaultNameStrategy](table)
if mi, ok := modelCache.get(name); ok { if mi, ok := modelCache.get(name); ok {
qs = newQuerySet(o, mi) qs = newQuerySet(o, mi)
} }
@ -458,11 +461,15 @@ func (o *orm) Using(name string) error {
// begin transaction // begin transaction
func (o *orm) Begin() error { func (o *orm) Begin() error {
return o.BeginTx(context.Background(), nil)
}
func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error {
if o.isTx { if o.isTx {
return ErrTxHasBegan return ErrTxHasBegan
} }
var tx *sql.Tx var tx *sql.Tx
tx, err := o.db.(txer).Begin() tx, err := o.db.(txer).BeginTx(ctx, opts)
if err != nil { if err != nil {
return err return err
} }
@ -515,6 +522,16 @@ func (o *orm) Driver() Driver {
return driver(o.alias.Name) return driver(o.alias.Name)
} }
// return sql.DBStats for current database
func (o *orm) DBStats() *sql.DBStats {
if o.alias != nil && o.alias.DB != nil {
stats := o.alias.DB.Stats()
return &stats
}
return nil
}
// NewOrm create new orm // NewOrm create new orm
func NewOrm() Ormer { func NewOrm() Ormer {
BootStrap() // execute only once BootStrap() // execute only once
@ -541,6 +558,9 @@ func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
al.Name = aliasName al.Name = aliasName
al.DriverName = driverName al.DriverName = driverName
al.DB = db
detectTZ(al)
o := new(orm) o := new(orm)
o.alias = al o.alias = al

View File

@ -31,6 +31,8 @@ type condValue struct {
isOr bool isOr bool
isNot bool isNot bool
isCond bool isCond bool
isRaw bool
sql string
} }
// Condition struct. // Condition struct.
@ -45,6 +47,15 @@ func NewCondition() *Condition {
return c return c
} }
// Raw add raw sql to condition
func (c Condition) Raw(expr string, sql string) *Condition {
if len(sql) == 0 {
panic(fmt.Errorf("<Condition.Raw> sql cannot empty"))
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true})
return &c
}
// And add expression to condition // And add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition { func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"io" "io"
@ -28,6 +29,9 @@ type Log struct {
*log.Logger *log.Logger
} }
//costomer log func
var LogFunc func(query map[string]interface{})
// NewLog set io.Writer to create a Logger. // NewLog set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log { func NewLog(out io.Writer) *Log {
d := new(Log) d := new(Log)
@ -36,12 +40,15 @@ func NewLog(out io.Writer) *Log {
} }
func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) {
var logMap = make(map[string]interface{})
sub := time.Now().Sub(t) / 1e5 sub := time.Now().Sub(t) / 1e5
elsp := float64(int(sub)) / 10.0 elsp := float64(int(sub)) / 10.0
logMap["cost_time"] = elsp
flag := " OK" flag := " OK"
if err != nil { if err != nil {
flag = "FAIL" flag = "FAIL"
} }
logMap["flag"] = flag
con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query) con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args)) cons := make([]string, 0, len(args))
for _, arg := range args { for _, arg := range args {
@ -53,6 +60,10 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if err != nil { if err != nil {
con += " - " + err.Error() con += " - " + err.Error()
} }
logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `"))
if LogFunc != nil{
LogFunc(logMap)
}
DebugLog.Println(con) DebugLog.Println(con)
} }
@ -122,6 +133,13 @@ func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) {
return stmt, err return stmt, err
} }
func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
a := time.Now()
stmt, err := d.db.PrepareContext(ctx, query)
debugLogQueies(d.alias, "db.Prepare", query, a, err)
return stmt, err
}
func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) {
a := time.Now() a := time.Now()
res, err := d.db.Exec(query, args...) res, err := d.db.Exec(query, args...)
@ -129,6 +147,13 @@ func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error)
return res, err return res, err
} }
func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.db.ExecContext(ctx, query, args...)
debugLogQueies(d.alias, "db.Exec", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now() a := time.Now()
res, err := d.db.Query(query, args...) res, err := d.db.Query(query, args...)
@ -136,6 +161,13 @@ func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error)
return res, err return res, err
} }
func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.db.QueryContext(ctx, query, args...)
debugLogQueies(d.alias, "db.Query", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
a := time.Now() a := time.Now()
res := d.db.QueryRow(query, args...) res := d.db.QueryRow(query, args...)
@ -143,6 +175,13 @@ func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
return res return res
} }
func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
a := time.Now()
res := d.db.QueryRowContext(ctx, query, args...)
debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...)
return res
}
func (d *dbQueryLog) Begin() (*sql.Tx, error) { func (d *dbQueryLog) Begin() (*sql.Tx, error) {
a := time.Now() a := time.Now()
tx, err := d.db.(txer).Begin() tx, err := d.db.(txer).Begin()
@ -150,6 +189,13 @@ func (d *dbQueryLog) Begin() (*sql.Tx, error) {
return tx, err return tx, err
} }
func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
a := time.Now()
tx, err := d.db.(txer).BeginTx(ctx, opts)
debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err)
return tx, err
}
func (d *dbQueryLog) Commit() error { func (d *dbQueryLog) Commit() error {
a := time.Now() a := time.Now()
err := d.db.(txEnder).Commit() err := d.db.(txEnder).Commit()

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"fmt" "fmt"
) )
@ -55,16 +56,19 @@ func ColValue(opt operator, value interface{}) interface{} {
// real query struct // real query struct
type querySet struct { type querySet struct {
mi *modelInfo mi *modelInfo
cond *Condition cond *Condition
related []string related []string
relDepth int relDepth int
limit int64 limit int64
offset int64 offset int64
groups []string groups []string
orders []string orders []string
distinct bool distinct bool
orm *orm forupdate bool
orm *orm
ctx context.Context
forContext bool
} }
var _ QuerySeter = new(querySet) var _ QuerySeter = new(querySet)
@ -78,6 +82,15 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
return &o return &o
} }
// add raw sql to querySeter.
func (o querySet) FilterRaw(expr string, sql string) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond = o.cond.Raw(expr, sql)
return &o
}
// add NOT condition to querySeter. // add NOT condition to querySeter.
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil { if o.cond == nil {
@ -127,6 +140,12 @@ func (o querySet) Distinct() QuerySeter {
return &o return &o
} }
// add FOR UPDATE to SELECT
func (o querySet) ForUpdate() QuerySeter {
o.forupdate = true
return &o
}
// set relation model to query together. // set relation model to query together.
// it will query relation models and assign to parent model. // it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter { func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
@ -259,6 +278,13 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string)
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// set context to QuerySeter.
func (o querySet) WithContext(ctx context.Context) QuerySeter {
o.ctx = ctx
o.forContext = true
return &o
}
// create new QuerySeter. // create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet) o := new(querySet)

View File

@ -150,8 +150,10 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
case reflect.Struct: case reflect.Struct:
if value == nil { if value == nil {
ind.Set(reflect.Zero(ind.Type())) ind.Set(reflect.Zero(ind.Type()))
return
} else if _, ok := ind.Interface().(time.Time); ok { }
switch ind.Interface().(type) {
case time.Time:
var str string var str string
switch d := value.(type) { switch d := value.(type) {
case time.Time: case time.Time:
@ -178,7 +180,25 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
} }
} }
} }
case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool:
indi := reflect.New(ind.Type()).Interface()
sc, ok := indi.(sql.Scanner)
if !ok {
return
}
err := sc.Scan(value)
if err == nil {
ind.Set(reflect.Indirect(reflect.ValueOf(sc)))
}
} }
case reflect.Ptr:
if value == nil {
ind.Set(reflect.Zero(ind.Type()))
break
}
ind.Set(reflect.New(ind.Type().Elem()))
o.setFieldValue(reflect.Indirect(ind), value)
} }
} }
@ -358,7 +378,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string var col string
if col = tags["column"]; col == "" { if col = tags["column"]; col == "" {
col = snakeString(fe.Name) col = nameStrategyMap[nameStrategy](fe.Name)
} }
if v, ok := columnsMp[col]; ok { if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface() value := reflect.ValueOf(v).Elem().Interface()
@ -509,7 +529,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string var col string
if col = tags["column"]; col == "" { if col = tags["column"]; col == "" {
col = snakeString(fe.Name) col = nameStrategyMap[nameStrategy](fe.Name)
} }
if v, ok := columnsMp[col]; ok { if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface() value := reflect.ValueOf(v).Elem().Interface()

View File

@ -12,10 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// +build go1.8
package orm package orm
import ( import (
"bytes" "bytes"
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -452,9 +455,18 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr))
throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr))
throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr))
throwFail(t, AssertIs((*d.TimePtr).Format(testTime), timePtr.Format(testTime))) throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime)))
throwFail(t, AssertIs((*d.DatePtr).Format(testDate), datePtr.Format(testDate))) throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate)))
throwFail(t, AssertIs((*d.DateTimePtr).Format(testDateTime), dateTimePtr.Format(testDateTime))) throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime)))
// test support for pointer fields using RawSeter.QueryRows()
var dnList []*DataNull
Q := dDbBaser.TableQuote()
num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
equal := reflect.DeepEqual(*dnList[0], d)
throwFailNow(t, AssertIs(equal, true))
} }
func TestDataCustomTypes(t *testing.T) { func TestDataCustomTypes(t *testing.T) {
@ -896,6 +908,18 @@ func TestOperators(t *testing.T) {
num, err = qs.Filter("id__between", []int{2, 3}).Count() num, err = qs.Filter("id__between", []int{2, 3}).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 2)) throwFail(t, AssertIs(num, 2))
num, err = qs.FilterRaw("user_name", "= 'slene'").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.FilterRaw("status", "IN (1, 2)").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 2))
num, err = qs.FilterRaw("profile_id", "IN (SELECT id FROM user_profile WHERE age=30)").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
} }
func TestSetCond(t *testing.T) { func TestSetCond(t *testing.T) {
@ -921,6 +945,11 @@ func TestSetCond(t *testing.T) {
num, err = qs.SetCond(cond4).Count() num, err = qs.SetCond(cond4).Count()
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(num, 3)) throwFail(t, AssertIs(num, 3))
cond5 := cond.Raw("user_name", "= 'slene'").OrNotCond(cond.And("user_name", "slene"))
num, err = qs.SetCond(cond5).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
} }
func TestLimit(t *testing.T) { func TestLimit(t *testing.T) {
@ -1659,6 +1688,31 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(uid, 4)) throwFail(t, AssertIs(uid, 4))
throwFail(t, AssertIs(*status, 3)) throwFail(t, AssertIs(*status, 3))
throwFail(t, AssertIs(pid, nil)) throwFail(t, AssertIs(pid, nil))
// test for sql.Null* fields
nData := &DataNull{
NullString: sql.NullString{String: "test sql.null", Valid: true},
NullBool: sql.NullBool{Bool: true, Valid: true},
NullInt64: sql.NullInt64{Int64: 42, Valid: true},
NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true},
}
newId, err := dORM.Insert(nData)
throwFailNow(t, err)
var nd *DataNull
query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q)
err = dORM.Raw(query, newId).QueryRow(&nd)
throwFailNow(t, err)
throwFailNow(t, AssertNot(nd, nil))
throwFail(t, AssertIs(nd.NullBool.Valid, true))
throwFail(t, AssertIs(nd.NullBool.Bool, true))
throwFail(t, AssertIs(nd.NullString.Valid, true))
throwFail(t, AssertIs(nd.NullString.String, "test sql.null"))
throwFail(t, AssertIs(nd.NullInt64.Valid, true))
throwFail(t, AssertIs(nd.NullInt64.Int64, 42))
throwFail(t, AssertIs(nd.NullFloat64.Valid, true))
throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42))
} }
// user_profile table // user_profile table
@ -1751,6 +1805,32 @@ func TestQueryRows(t *testing.T) {
throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) throwFailNow(t, AssertIs(l[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(l[1].Age, 30)) throwFailNow(t, AssertIs(l[1].Age, 30))
// test for sql.Null* fields
nData := &DataNull{
NullString: sql.NullString{String: "test sql.null", Valid: true},
NullBool: sql.NullBool{Bool: true, Valid: true},
NullInt64: sql.NullInt64{Int64: 42, Valid: true},
NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true},
}
newId, err := dORM.Insert(nData)
throwFailNow(t, err)
var nDataList []*DataNull
query = fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q)
num, err = dORM.Raw(query, newId).QueryRows(&nDataList)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
nd := nDataList[0]
throwFailNow(t, AssertNot(nd, nil))
throwFail(t, AssertIs(nd.NullBool.Valid, true))
throwFail(t, AssertIs(nd.NullBool.Bool, true))
throwFail(t, AssertIs(nd.NullString.Valid, true))
throwFail(t, AssertIs(nd.NullString.String, "test sql.null"))
throwFail(t, AssertIs(nd.NullInt64.Valid, true))
throwFail(t, AssertIs(nd.NullInt64.Int64, 42))
throwFail(t, AssertIs(nd.NullFloat64.Valid, true))
throwFail(t, AssertIs(nd.NullFloat64.Float64, 42.42))
} }
func TestRawValues(t *testing.T) { func TestRawValues(t *testing.T) {
@ -1990,6 +2070,66 @@ func TestTransaction(t *testing.T) {
} }
func TestTransactionIsolationLevel(t *testing.T) {
// this test worked when database support transaction isolation level
if IsSqlite {
return
}
o1 := NewOrm()
o2 := NewOrm()
// start two transaction with isolation level repeatable read
err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
throwFail(t, err)
err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
throwFail(t, err)
// o1 insert tag
var tag Tag
tag.Name = "test-transaction"
id, err := o1.Insert(&tag)
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
// o2 query tag table, no result
num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 0))
// o1 commit
o1.Commit()
// o2 query tag table, still no result
num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 0))
// o2 commit and query tag table, get the result
o2.Commit()
num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = o1.QueryTable("tag").Filter("name", "test-transaction").Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
func TestBeginTxWithContextCanceled(t *testing.T) {
o := NewOrm()
ctx, cancel := context.WithCancel(context.Background())
o.BeginTx(ctx, nil)
id, err := o.Insert(&Tag{Name: "test-context"})
throwFail(t, err)
throwFail(t, AssertIs(id > 0, true))
// cancel the context before commit to make it error
cancel()
err = o.Commit()
throwFail(t, AssertIs(err, context.Canceled))
}
func TestReadOrCreate(t *testing.T) { func TestReadOrCreate(t *testing.T) {
u := &User{ u := &User{
UserName: "Kyle", UserName: "Kyle",
@ -2260,6 +2400,7 @@ func TestIgnoreCaseTag(t *testing.T) {
throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name"))
throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name"))
} }
func TestInsertOrUpdate(t *testing.T) { func TestInsertOrUpdate(t *testing.T) {
RegisterModel(new(User)) RegisterModel(new(User))
user := User{UserName: "unique_username133", Status: 1, Password: "o"} user := User{UserName: "unique_username133", Status: 1, Password: "o"}
@ -2297,6 +2438,11 @@ func TestInsertOrUpdate(t *testing.T) {
throwFailNow(t, AssertIs(user2.Status, test.Status)) throwFailNow(t, AssertIs(user2.Status, test.Status))
throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password)))
} }
//postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
if IsPostgres {
return
}
//test3 + //test3 +
_, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1")
if err != nil { if err != nil {

View File

@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"database/sql" "database/sql"
"reflect" "reflect"
"time" "time"
@ -106,6 +107,17 @@ type Ormer interface {
// ... // ...
// err = o.Rollback() // err = o.Rollback()
Begin() error Begin() error
// begin transaction with provided context and option
// the provided context is used until the transaction is committed or rolled back.
// if the context is canceled, the transaction will be rolled back.
// the provided TxOptions is optional and may be nil if defaults should be used.
// if a non-default isolation level is used that the driver doesn't support, an error will be returned.
// for example:
// o := NewOrm()
// err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
// ...
// err = o.Rollback()
BeginTx(ctx context.Context, opts *sql.TxOptions) error
// commit transaction // commit transaction
Commit() error Commit() error
// rollback transaction // rollback transaction
@ -116,6 +128,7 @@ type Ormer interface {
// // update user testing's name to slene // // update user testing's name to slene
Raw(query string, args ...interface{}) RawSeter Raw(query string, args ...interface{}) RawSeter
Driver() Driver Driver() Driver
DBStats() *sql.DBStats
} }
// Inserter insert prepared statement // Inserter insert prepared statement
@ -135,6 +148,11 @@ type QuerySeter interface {
// // time compare // // time compare
// qs.Filter("created", time.Now()) // qs.Filter("created", time.Now())
Filter(string, ...interface{}) QuerySeter Filter(string, ...interface{}) QuerySeter
// add raw sql to querySeter.
// for example:
// qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)")
// //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18)
FilterRaw(string, string) QuerySeter
// add NOT condition to querySeter. // add NOT condition to querySeter.
// have the same usage as Filter // have the same usage as Filter
Exclude(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter
@ -190,6 +208,10 @@ type QuerySeter interface {
// Distinct(). // Distinct().
// All(&permissions) // All(&permissions)
Distinct() QuerySeter Distinct() QuerySeter
// set FOR UPDATE to query.
// for example:
// o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users)
ForUpdate() QuerySeter
// return QuerySeter execution result number // return QuerySeter execution result number
// for example: // for example:
// num, err = qs.Filter("profile__age__gt", 28).Count() // num, err = qs.Filter("profile__age__gt", 28).Count()
@ -374,16 +396,23 @@ type RawSeter interface {
type stmtQuerier interface { type stmtQuerier interface {
Close() error Close() error
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
//ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error) Query(args ...interface{}) (*sql.Rows, error)
//QueryContext(args ...interface{}) (*sql.Rows, error)
QueryRow(args ...interface{}) *sql.Row QueryRow(args ...interface{}) *sql.Row
//QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
} }
// db querier // db querier
type dbQuerier interface { type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error) Prepare(query string) (*sql.Stmt, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error) Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row QueryRow(query string, args ...interface{}) *sql.Row
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
} }
// type DB interface { // type DB interface {
@ -397,6 +426,7 @@ type dbQuerier interface {
// transaction beginner // transaction beginner
type txer interface { type txer interface {
Begin() (*sql.Tx, error) Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
} }
// transaction ending // transaction ending

View File

@ -23,6 +23,18 @@ import (
"time" "time"
) )
type fn func(string) string
var (
nameStrategyMap = map[string]fn{
defaultNameStrategy: snakeString,
SnakeAcronymNameStrategy: snakeStringWithAcronym,
}
defaultNameStrategy = "snakeString"
SnakeAcronymNameStrategy = "snakeStringWithAcronym"
nameStrategy = defaultNameStrategy
)
// StrTo is the target string // StrTo is the target string
type StrTo string type StrTo string
@ -198,7 +210,28 @@ func ToInt64(value interface{}) (d int64) {
return return
} }
// snake string, XxYy to xx_yy , XxYY to xx_yy func snakeStringWithAcronym(s string) string {
data := make([]byte, 0, len(s)*2)
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
before := false
after := false
if i > 0 {
before = s[i-1] >= 'a' && s[i-1] <= 'z'
}
if i+1 < num {
after = s[i+1] >= 'a' && s[i+1] <= 'z'
}
if i > 0 && d >= 'A' && d <= 'Z' && (before || after) {
data = append(data, '_')
}
data = append(data, d)
}
return strings.ToLower(string(data[:]))
}
// snake string, XxYy to xx_yy , XxYY to xx_y_y
func snakeString(s string) string { func snakeString(s string) string {
data := make([]byte, 0, len(s)*2) data := make([]byte, 0, len(s)*2)
j := false j := false
@ -216,6 +249,14 @@ func snakeString(s string) string {
return strings.ToLower(string(data[:])) return strings.ToLower(string(data[:]))
} }
// SetNameStrategy set different name strategy
func SetNameStrategy(s string) {
if SnakeAcronymNameStrategy != s {
nameStrategy = defaultNameStrategy
}
nameStrategy = s
}
// camel string, xx_yy to XxYy // camel string, xx_yy to XxYy
func camelString(s string) string { func camelString(s string) string {
data := make([]byte, 0, len(s)) data := make([]byte, 0, len(s))

View File

@ -34,3 +34,37 @@ func TestCamelString(t *testing.T) {
} }
} }
} }
func TestSnakeString(t *testing.T) {
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"}
snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"}
answer := make(map[string]string)
for i, v := range camel {
answer[v] = snake[i]
}
for _, v := range camel {
res := snakeString(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
}
func TestSnakeStringWithAcronym(t *testing.T) {
camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"}
snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"}
answer := make(map[string]string)
for i, v := range camel {
answer[v] = snake[i]
}
for _, v := range camel {
res := snakeStringWithAcronym(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
}

277
parser.go
View File

@ -39,7 +39,7 @@ var globalRouterTemplate = `package routers
import ( import (
"github.com/astaxie/beego" "github.com/astaxie/beego"
"github.com/astaxie/beego/context/param" "github.com/astaxie/beego/context/param"{{.globalimport}}
) )
func init() { func init() {
@ -52,6 +52,22 @@ var (
commentFilename string commentFilename string
pkgLastupdate map[string]int64 pkgLastupdate map[string]int64
genInfoList map[string][]ControllerComments genInfoList map[string][]ControllerComments
routerHooks = map[string]int{
"beego.BeforeStatic": BeforeStatic,
"beego.BeforeRouter": BeforeRouter,
"beego.BeforeExec": BeforeExec,
"beego.AfterExec": AfterExec,
"beego.FinishRouter": FinishRouter,
}
routerHooksMapping = map[int]string{
BeforeStatic: "beego.BeforeStatic",
BeforeRouter: "beego.BeforeRouter",
BeforeExec: "beego.BeforeExec",
AfterExec: "beego.AfterExec",
FinishRouter: "beego.FinishRouter",
}
) )
const commentPrefix = "commentsRouter_" const commentPrefix = "commentsRouter_"
@ -102,6 +118,20 @@ type parsedComment struct {
routerPath string routerPath string
methods []string methods []string
params map[string]parsedParam params map[string]parsedParam
filters []parsedFilter
imports []parsedImport
}
type parsedImport struct {
importPath string
importAlias string
}
type parsedFilter struct {
pattern string
pos int
filter string
params []bool
} }
type parsedParam struct { type parsedParam struct {
@ -114,24 +144,69 @@ type parsedParam struct {
func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
if f.Doc != nil { if f.Doc != nil {
parsedComment, err := parseComment(f.Doc.List) parsedComments, err := parseComment(f.Doc.List)
if err != nil { if err != nil {
return err return err
} }
if parsedComment.routerPath != "" { for _, parsedComment := range parsedComments {
key := pkgpath + ":" + controllerName if parsedComment.routerPath != "" {
cc := ControllerComments{} key := pkgpath + ":" + controllerName
cc.Method = f.Name.String() cc := ControllerComments{}
cc.Router = parsedComment.routerPath cc.Method = f.Name.String()
cc.AllowHTTPMethods = parsedComment.methods cc.Router = parsedComment.routerPath
cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) cc.AllowHTTPMethods = parsedComment.methods
genInfoList[key] = append(genInfoList[key], cc) cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
cc.FilterComments = buildFilters(parsedComment.filters)
cc.ImportComments = buildImports(parsedComment.imports)
genInfoList[key] = append(genInfoList[key], cc)
}
} }
} }
return nil return nil
} }
func buildImports(pis []parsedImport) []*ControllerImportComments {
var importComments []*ControllerImportComments
for _, pi := range pis {
importComments = append(importComments, &ControllerImportComments{
ImportPath: pi.importPath,
ImportAlias: pi.importAlias,
})
}
return importComments
}
func buildFilters(pfs []parsedFilter) []*ControllerFilterComments {
var filterComments []*ControllerFilterComments
for _, pf := range pfs {
var (
returnOnOutput bool
resetParams bool
)
if len(pf.params) >= 1 {
returnOnOutput = pf.params[0]
}
if len(pf.params) >= 2 {
resetParams = pf.params[1]
}
filterComments = append(filterComments, &ControllerFilterComments{
Filter: pf.filter,
Pattern: pf.pattern,
Pos: pf.pos,
ReturnOnOutput: returnOnOutput,
ResetParams: resetParams,
})
}
return filterComments
}
func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
result := make([]*param.MethodParam, 0, len(funcParams)) result := make([]*param.MethodParam, 0, len(funcParams))
for _, fparam := range funcParams { for _, fparam := range funcParams {
@ -177,26 +252,15 @@ func paramInPath(name, route string) bool {
var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) { func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) {
pc = &parsedComment{} pcs = []*parsedComment{}
params := map[string]parsedParam{}
filters := []parsedFilter{}
imports := []parsedImport{}
for _, c := range lines { for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@router") { if strings.HasPrefix(t, "@Param") {
matches := routeRegex.FindStringSubmatch(t)
if len(matches) == 3 {
pc.routerPath = matches[1]
methods := matches[2]
if methods == "" {
pc.methods = []string{"get"}
//pc.hasGet = true
} else {
pc.methods = strings.Split(methods, ",")
//pc.hasGet = strings.Contains(methods, "get")
}
} else {
return nil, errors.New("Router information is missing")
}
} else if strings.HasPrefix(t, "@Param") {
pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param")))
if len(pv) < 4 { if len(pv) < 4 {
logs.Error("Invalid @Param format. Needs at least 4 parameters") logs.Error("Invalid @Param format. Needs at least 4 parameters")
@ -217,17 +281,99 @@ func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) {
p.defValue = pv[3] p.defValue = pv[3]
p.required, _ = strconv.ParseBool(pv[4]) p.required, _ = strconv.ParseBool(pv[4])
} }
if pc.params == nil { params[funcParamName] = p
pc.params = map[string]parsedParam{} }
}
for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@Import") {
iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import")))
if len(iv) == 0 || len(iv) > 2 {
logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters")
continue
}
p := parsedImport{}
p.importPath = iv[0]
if len(iv) == 2 {
p.importAlias = iv[1]
}
imports = append(imports, p)
}
}
filterLoop:
for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@Filter") {
fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter")))
if len(fv) < 3 {
logs.Error("Invalid @Filter format. Needs at least 3 parameters")
continue filterLoop
}
p := parsedFilter{}
p.pattern = fv[0]
posName := fv[1]
if pos, exists := routerHooks[posName]; exists {
p.pos = pos
} else {
logs.Error("Invalid @Filter pos: ", posName)
continue filterLoop
}
p.filter = fv[2]
fvParams := fv[3:]
for _, fvParam := range fvParams {
switch fvParam {
case "true":
p.params = append(p.params, true)
case "false":
p.params = append(p.params, false)
default:
logs.Error("Invalid @Filter param: ", fvParam)
continue filterLoop
}
}
filters = append(filters, p)
}
}
for _, c := range lines {
var pc = &parsedComment{}
pc.params = params
pc.filters = filters
pc.imports = imports
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@router") {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
matches := routeRegex.FindStringSubmatch(t)
if len(matches) == 3 {
pc.routerPath = matches[1]
methods := matches[2]
if methods == "" {
pc.methods = []string{"get"}
//pc.hasGet = true
} else {
pc.methods = strings.Split(methods, ",")
//pc.hasGet = strings.Contains(methods, "get")
}
pcs = append(pcs, pc)
} else {
return nil, errors.New("Router information is missing")
} }
pc.params[funcParamName] = p
} }
} }
return return
} }
// direct copy from bee\g_docs.go // direct copy from bee\g_docs.go
// analisys params return []string // analysis params return []string
// @Param query form string true "The email for login" // @Param query form string true "The email for login"
// [query form string true "The email for login"] // [query form string true "The email for login"]
func getparams(str string) []string { func getparams(str string) []string {
@ -266,8 +412,9 @@ func genRouterCode(pkgRealpath string) {
os.Mkdir(getRouterDir(pkgRealpath), 0755) os.Mkdir(getRouterDir(pkgRealpath), 0755)
logs.Info("generate router from comments") logs.Info("generate router from comments")
var ( var (
globalinfo string globalinfo string
sortKey []string globalimport string
sortKey []string
) )
for k := range genInfoList { for k := range genInfoList {
sortKey = append(sortKey, k) sortKey = append(sortKey, k)
@ -285,6 +432,7 @@ func genRouterCode(pkgRealpath string) {
} }
allmethod = strings.TrimRight(allmethod, ",") + "}" allmethod = strings.TrimRight(allmethod, ",") + "}"
} }
params := "nil" params := "nil"
if len(c.Params) > 0 { if len(c.Params) > 0 {
params = "[]map[string]string{" params = "[]map[string]string{"
@ -295,6 +443,7 @@ func genRouterCode(pkgRealpath string) {
} }
params = strings.TrimRight(params, ",") + "}" params = strings.TrimRight(params, ",") + "}"
} }
methodParams := "param.Make(" methodParams := "param.Make("
if len(c.MethodParams) > 0 { if len(c.MethodParams) > 0 {
lines := make([]string, 0, len(c.MethodParams)) lines := make([]string, 0, len(c.MethodParams))
@ -306,24 +455,66 @@ func genRouterCode(pkgRealpath string) {
",\n " ",\n "
} }
methodParams += ")" methodParams += ")"
imports := ""
if len(c.ImportComments) > 0 {
for _, i := range c.ImportComments {
if i.ImportAlias != "" {
imports += fmt.Sprintf(`
%s "%s"`, i.ImportAlias, i.ImportPath)
} else {
imports += fmt.Sprintf(`
"%s"`, i.ImportPath)
}
}
}
filters := ""
if len(c.FilterComments) > 0 {
for _, f := range c.FilterComments {
filters += fmt.Sprintf(` &beego.ControllerFilter{
Pattern: "%s",
Pos: %s,
Filter: %s,
ReturnOnOutput: %v,
ResetParams: %v,
},`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams)
}
}
if filters == "" {
filters = "nil"
} else {
filters = fmt.Sprintf(`[]*beego.ControllerFilter{
%s
}`, filters)
}
globalimport = imports
globalinfo = globalinfo + ` globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{ beego.ControllerComments{
Method: "` + strings.TrimSpace(c.Method) + `", Method: "` + strings.TrimSpace(c.Method) + `",
` + "Router: `" + c.Router + "`" + `, ` + "Router: `" + c.Router + "`" + `,
AllowHTTPMethods: ` + allmethod + `, AllowHTTPMethods: ` + allmethod + `,
MethodParams: ` + methodParams + `, MethodParams: ` + methodParams + `,
Params: ` + params + `}) Filters: ` + filters + `,
Params: ` + params + `})
` `
} }
} }
if globalinfo != "" { if globalinfo != "" {
f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename))
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer f.Close() defer f.Close()
f.WriteString(strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1))
content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1)
content = strings.Replace(content, "{{.globalimport}}", globalimport, -1)
f.WriteString(content)
} }
} }

View File

@ -72,8 +72,8 @@ import (
// AppIDToAppSecret is used to get appsecret throw appid // AppIDToAppSecret is used to get appsecret throw appid
type AppIDToAppSecret func(string) string type AppIDToAppSecret func(string) string
// APIBaiscAuth use the basic appid/appkey as the AppIdToAppSecret // APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret
func APIBaiscAuth(appid, appkey string) beego.FilterFunc { func APIBasicAuth(appid, appkey string) beego.FilterFunc {
ft := func(aid string) string { ft := func(aid string) string {
if aid == appid { if aid == appid {
return appkey return appkey
@ -83,6 +83,11 @@ func APIBaiscAuth(appid, appkey string) beego.FilterFunc {
return APISecretAuth(ft, 300) return APISecretAuth(ft, 300)
} }
// APIBaiscAuth calls APIBasicAuth for previous callers
func APIBaiscAuth(appid, appkey string) beego.FilterFunc {
return APIBaiscAuth(appid, appkey)
}
// APISecretAuth use AppIdToAppSecret verify and // APISecretAuth use AppIdToAppSecret verify and
func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc {
return func(ctx *context.Context) { return func(ctx *context.Context) {

220
router.go
View File

@ -15,6 +15,7 @@
package beego package beego
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"path" "path"
@ -43,35 +44,35 @@ const (
) )
const ( const (
routerTypeBeego = iota routerTypeBeego = iota
routerTypeRESTFul routerTypeRESTFul
routerTypeHandler routerTypeHandler
) )
var ( var (
// HTTPMETHOD list the supported http methods. // HTTPMETHOD list the supported http methods.
HTTPMETHOD = map[string]string{ HTTPMETHOD = map[string]bool{
"GET": "GET", "GET": true,
"POST": "POST", "POST": true,
"PUT": "PUT", "PUT": true,
"DELETE": "DELETE", "DELETE": true,
"PATCH": "PATCH", "PATCH": true,
"OPTIONS": "OPTIONS", "OPTIONS": true,
"HEAD": "HEAD", "HEAD": true,
"TRACE": "TRACE", "TRACE": true,
"CONNECT": "CONNECT", "CONNECT": true,
"MKCOL": "MKCOL", "MKCOL": true,
"COPY": "COPY", "COPY": true,
"MOVE": "MOVE", "MOVE": true,
"PROPFIND": "PROPFIND", "PROPFIND": true,
"PROPPATCH": "PROPPATCH", "PROPPATCH": true,
"LOCK": "LOCK", "LOCK": true,
"UNLOCK": "UNLOCK", "UNLOCK": true,
} }
// these beego.Controller's methods shouldn't reflect to AutoRouter // these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
"RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP", "RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP",
"ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool", "ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
"GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession", "GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
"DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie", "DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
@ -133,14 +134,15 @@ type ControllerRegister struct {
// NewControllerRegister returns a new ControllerRegister. // NewControllerRegister returns a new ControllerRegister.
func NewControllerRegister() *ControllerRegister { func NewControllerRegister() *ControllerRegister {
cr := &ControllerRegister{ return &ControllerRegister{
routers: make(map[string]*Tree), routers: make(map[string]*Tree),
policies: make(map[string]*Tree), policies: make(map[string]*Tree),
pool: sync.Pool{
New: func() interface{} {
return beecontext.NewContext()
},
},
} }
cr.pool.New = func() interface{} {
return beecontext.NewContext()
}
return cr
} }
// Add controller handler and pattern rules to ControllerRegister. // Add controller handler and pattern rules to ControllerRegister.
@ -170,7 +172,7 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
} }
comma := strings.Split(colon[0], ",") comma := strings.Split(colon[0], ",")
for _, m := range comma { for _, m := range comma {
if _, ok := HTTPMETHOD[strings.ToUpper(m)]; m == "*" || ok { if m == "*" || HTTPMETHOD[strings.ToUpper(m)] {
if val := reflectVal.MethodByName(colon[1]); val.IsValid() { if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
methods[strings.ToUpper(m)] = colon[1] methods[strings.ToUpper(m)] = colon[1]
} else { } else {
@ -201,9 +203,12 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
numOfFields := elemVal.NumField() numOfFields := elemVal.NumField()
for i := 0; i < numOfFields; i++ { for i := 0; i < numOfFields; i++ {
fieldVal := elemVal.Field(i)
fieldType := elemType.Field(i) fieldType := elemType.Field(i)
execElem.FieldByName(fieldType.Name).Set(fieldVal) elemField := execElem.FieldByName(fieldType.Name)
if elemField.CanSet() {
fieldVal := elemVal.Field(i)
elemField.Set(fieldVal)
}
} }
return execController return execController
@ -211,13 +216,13 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
route.methodParams = methodParams route.methodParams = methodParams
if len(methods) == 0 { if len(methods) == 0 {
for _, m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
} }
} else { } else {
for k := range methods { for k := range methods {
if k == "*" { if k == "*" {
for _, m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
} }
} else { } else {
@ -274,6 +279,10 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
key := t.PkgPath() + ":" + t.Name() key := t.PkgPath() + ":" + t.Name()
if comm, ok := GlobalControllerRouter[key]; ok { if comm, ok := GlobalControllerRouter[key]; ok {
for _, a := range comm { for _, a := range comm {
for _, f := range a.Filters {
p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams)
}
p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
} }
} }
@ -359,7 +368,7 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) {
// }) // })
func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
method = strings.ToUpper(method) method = strings.ToUpper(method)
if _, ok := HTTPMETHOD[method]; method != "*" && !ok { if method != "*" && !HTTPMETHOD[method] {
panic("not support http method: " + method) panic("not support http method: " + method)
} }
route := &ControllerInfo{} route := &ControllerInfo{}
@ -368,7 +377,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
route.runFunction = f route.runFunction = f
methods := make(map[string]string) methods := make(map[string]string)
if method == "*" { if method == "*" {
for _, val := range HTTPMETHOD { for val := range HTTPMETHOD {
methods[val] = val methods[val] = val
} }
} else { } else {
@ -377,7 +386,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
route.methods = methods route.methods = methods
for k := range methods { for k := range methods {
if k == "*" { if k == "*" {
for _, m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
} }
} else { } else {
@ -397,7 +406,7 @@ func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...
pattern = path.Join(pattern, "?:all(.*)") pattern = path.Join(pattern, "?:all(.*)")
} }
} }
for _, m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
} }
} }
@ -432,7 +441,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
route.pattern = pattern route.pattern = pattern
for _, m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
p.addToRouter(m, patternInit, route) p.addToRouter(m, patternInit, route)
p.addToRouter(m, patternFix, route) p.addToRouter(m, patternFix, route)
@ -471,8 +480,7 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
// add Filter into // add Filter into
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
if pos < BeforeStatic || pos > FinishRouter { if pos < BeforeStatic || pos > FinishRouter {
err = fmt.Errorf("can not find your filter position") return errors.New("can not find your filter position")
return
} }
p.enableFilter = true p.enableFilter = true
p.filters[pos] = append(p.filters[pos], mr) p.filters[pos] = append(p.filters[pos], mr)
@ -502,10 +510,10 @@ func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) stri
} }
} }
} }
controllName := strings.Join(paths[:len(paths)-1], "/") controllerName := strings.Join(paths[:len(paths)-1], "/")
methodName := paths[len(paths)-1] methodName := paths[len(paths)-1]
for m, t := range p.routers { for m, t := range p.routers {
ok, url := p.geturl(t, "/", controllName, methodName, params, m) ok, url := p.getURL(t, "/", controllerName, methodName, params, m)
if ok { if ok {
return url return url
} }
@ -513,17 +521,17 @@ func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) stri
return "" return ""
} }
func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName string, params map[string]string, httpMethod string) (bool, string) { func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) {
for _, subtree := range t.fixrouters { for _, subtree := range t.fixrouters {
u := path.Join(url, subtree.prefix) u := path.Join(url, subtree.prefix)
ok, u := p.geturl(subtree, u, controllName, methodName, params, httpMethod) ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod)
if ok { if ok {
return ok, u return ok, u
} }
} }
if t.wildcard != nil { if t.wildcard != nil {
u := path.Join(url, urlPlaceholder) u := path.Join(url, urlPlaceholder)
ok, u := p.geturl(t.wildcard, u, controllName, methodName, params, httpMethod) ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod)
if ok { if ok {
return ok, u return ok, u
} }
@ -531,9 +539,9 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
for _, l := range t.leaves { for _, l := range t.leaves {
if c, ok := l.runObject.(*ControllerInfo); ok { if c, ok := l.runObject.(*ControllerInfo); ok {
if c.routerType == routerTypeBeego && if c.routerType == routerTypeBeego &&
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) { strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) {
find := false find := false
if _, ok := HTTPMETHOD[strings.ToUpper(methodName)]; ok { if HTTPMETHOD[strings.ToUpper(methodName)] {
if len(c.methods) == 0 { if len(c.methods) == 0 {
find = true find = true
} else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) { } else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) {
@ -570,18 +578,18 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
} }
} }
} }
canskip := false canSkip := false
for _, v := range l.wildcards { for _, v := range l.wildcards {
if v == ":" { if v == ":" {
canskip = true canSkip = true
continue continue
} }
if u, ok := params[v]; ok { if u, ok := params[v]; ok {
delete(params, v) delete(params, v)
url = strings.Replace(url, urlPlaceholder, u, 1) url = strings.Replace(url, urlPlaceholder, u, 1)
} else { } else {
if canskip { if canSkip {
canskip = false canSkip = false
continue continue
} }
return false, "" return false, ""
@ -590,27 +598,27 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
return true, url + toURL(params) return true, url + toURL(params)
} }
var i int var i int
var startreg bool var startReg bool
regurl := "" regURL := ""
for _, v := range strings.Trim(l.regexps.String(), "^$") { for _, v := range strings.Trim(l.regexps.String(), "^$") {
if v == '(' { if v == '(' {
startreg = true startReg = true
continue continue
} else if v == ')' { } else if v == ')' {
startreg = false startReg = false
if v, ok := params[l.wildcards[i]]; ok { if v, ok := params[l.wildcards[i]]; ok {
delete(params, l.wildcards[i]) delete(params, l.wildcards[i])
regurl = regurl + v regURL = regURL + v
i++ i++
} else { } else {
break break
} }
} else if !startreg { } else if !startReg {
regurl = string(append([]rune(regurl), v)) regURL = string(append([]rune(regURL), v))
} }
} }
if l.regexps.MatchString(regurl) { if l.regexps.MatchString(regURL) {
ps := strings.Split(regurl, "/") ps := strings.Split(regURL, "/")
for _, p := range ps { for _, p := range ps {
url = strings.Replace(url, urlPlaceholder, p, 1) url = strings.Replace(url, urlPlaceholder, p, 1)
} }
@ -681,8 +689,8 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
// filter wrong http method // filter wrong http method
if _, ok := HTTPMETHOD[r.Method]; !ok { if !HTTPMETHOD[r.Method] {
http.Error(rw, "Method Not Allowed", 405) http.Error(rw, "Method Not Allowed", http.StatusMethodNotAllowed)
goto Admin goto Admin
} }
@ -791,7 +799,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if !isRunnable { if !isRunnable {
//Invoke the request handler //Invoke the request handler
var execController ControllerInterface var execController ControllerInterface
if routerInfo.initialize != nil { if routerInfo != nil && routerInfo.initialize != nil {
execController = routerInfo.initialize() execController = routerInfo.initialize()
} else { } else {
vc := reflect.New(runRouter) vc := reflect.New(runRouter)
@ -874,15 +882,18 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
Admin: Admin:
//admin module record QPS //admin module record QPS
statusCode := context.ResponseWriter.Status statusCode := context.ResponseWriter.Status
if statusCode == 0 { if statusCode == 0 {
statusCode = 200 statusCode = 200
} }
logAccess(context, &startTime, statusCode)
timeDur := time.Since(startTime)
context.ResponseWriter.Elapsed = timeDur
if BConfig.Listen.EnableAdmin { if BConfig.Listen.EnableAdmin {
timeDur := time.Since(startTime)
pattern := "" pattern := ""
if routerInfo != nil { if routerInfo != nil {
pattern = routerInfo.pattern pattern = routerInfo.pattern
@ -897,49 +908,29 @@ Admin:
} }
} }
if BConfig.RunMode == DEV || BConfig.Log.AccessLogs { if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs {
timeDur := time.Since(startTime)
var devInfo string var devInfo string
iswin := (runtime.GOOS == "windows") iswin := (runtime.GOOS == "windows")
statusColor := logs.ColorByStatus(iswin, statusCode) statusColor := logs.ColorByStatus(iswin, statusCode)
methodColor := logs.ColorByMethod(iswin, r.Method) methodColor := logs.ColorByMethod(iswin, r.Method)
resetColor := logs.ColorByMethod(iswin, "") resetColor := logs.ColorByMethod(iswin, "")
if BConfig.Log.AccessLogsFormat != "" { if findRouter {
record := &logs.AccessLogRecord{ if routerInfo != nil {
RemoteAddr: context.Input.IP(), devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s r:%s", context.Input.IP(), statusColor, statusCode,
RequestTime: startTime, resetColor, timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path,
RequestMethod: r.Method, routerInfo.pattern)
Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto),
ServerProtocol: r.Proto,
Host: r.Host,
Status: statusCode,
ElapsedTime: timeDur,
HttpReferrer: r.Header.Get("Referer"),
HttpUserAgent: r.Header.Get("User-Agent"),
RemoteUser: r.Header.Get("Remote-User"),
BodyBytesSent: 0, //@todo this one is missing!
}
logs.AccessLog(record, BConfig.Log.AccessLogsFormat)
}else {
if findRouter {
if routerInfo != nil {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s r:%s", context.Input.IP(), statusColor, statusCode,
resetColor, timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path,
routerInfo.pattern)
} else {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path)
}
} else { } else {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor, devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
timeDur.String(), "nomatch", methodColor, r.Method, resetColor, r.URL.Path) timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path)
}
if iswin {
logs.W32Debug(devInfo)
} else {
logs.Debug(devInfo)
} }
} else {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
timeDur.String(), "nomatch", methodColor, r.Method, resetColor, r.URL.Path)
}
if iswin {
logs.W32Debug(devInfo)
} else {
logs.Debug(devInfo)
} }
} }
// Call WriteHeader if status code has been set changed // Call WriteHeader if status code has been set changed
@ -957,7 +948,7 @@ func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, ex
context.RenderMethodResult(resultValue) context.RenderMethodResult(resultValue)
} }
} }
if !context.ResponseWriter.Started && context.Output.Status == 0 { if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 {
context.Output.SetStatus(200) context.Output.SetStatus(200)
} }
} }
@ -988,3 +979,38 @@ func toURL(params map[string]string) string {
} }
return strings.TrimRight(u, "&") return strings.TrimRight(u, "&")
} }
func logAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) {
//Skip logging if AccessLogs config is false
if !BConfig.Log.AccessLogs {
return
}
//Skip logging static requests unless EnableStaticLogs config is true
if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) {
return
}
var (
requestTime time.Time
elapsedTime time.Duration
r = ctx.Request
)
if startTime != nil {
requestTime = *startTime
elapsedTime = time.Since(*startTime)
}
record := &logs.AccessLogRecord{
RemoteAddr: ctx.Input.IP(),
RequestTime: requestTime,
RequestMethod: r.Method,
Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto),
ServerProtocol: r.Proto,
Host: r.Host,
Status: statusCode,
ElapsedTime: elapsedTime,
HTTPReferrer: r.Header.Get("Referer"),
HTTPUserAgent: r.Header.Get("User-Agent"),
RemoteUser: r.Header.Get("Remote-User"),
BodyBytesSent: 0, //@todo this one is missing!
}
logs.AccessLog(record, BConfig.Log.AccessLogsFormat)
}

View File

@ -71,10 +71,6 @@ func (tc *TestController) GetEmptyBody() {
tc.Ctx.Output.Body(res) tc.Ctx.Output.Body(res)
} }
type ResStatus struct {
Code int
Msg string
}
type JSONController struct { type JSONController struct {
Controller Controller
@ -475,7 +471,7 @@ func TestParamResetFilter(t *testing.T) {
// a response header of `Splat`. The expectation here is that that Header // a response header of `Splat`. The expectation here is that that Header
// value should match what the _request's_ router set, not the filter's. // value should match what the _request's_ router set, not the filter's.
headers := rw.HeaderMap headers := rw.Result().Header
if len(headers["Splat"]) != 1 { if len(headers["Splat"]) != 1 {
t.Errorf( t.Errorf(
"%s: There was an error in the test. Splat param not set in Header", "%s: There was an error in the test. Splat param not set in Header",
@ -660,25 +656,16 @@ func beegoBeforeRouter1(ctx *context.Context) {
ctx.WriteString("|BeforeRouter1") ctx.WriteString("|BeforeRouter1")
} }
func beegoBeforeRouter2(ctx *context.Context) {
ctx.WriteString("|BeforeRouter2")
}
func beegoBeforeExec1(ctx *context.Context) { func beegoBeforeExec1(ctx *context.Context) {
ctx.WriteString("|BeforeExec1") ctx.WriteString("|BeforeExec1")
} }
func beegoBeforeExec2(ctx *context.Context) {
ctx.WriteString("|BeforeExec2")
}
func beegoAfterExec1(ctx *context.Context) { func beegoAfterExec1(ctx *context.Context) {
ctx.WriteString("|AfterExec1") ctx.WriteString("|AfterExec1")
} }
func beegoAfterExec2(ctx *context.Context) {
ctx.WriteString("|AfterExec2")
}
func beegoFinishRouter1(ctx *context.Context) { func beegoFinishRouter1(ctx *context.Context) {
ctx.WriteString("|FinishRouter1") ctx.WriteString("|FinishRouter1")
@ -695,3 +682,30 @@ func beegoResetParams(ctx *context.Context) {
func beegoHandleResetParams(ctx *context.Context) { func beegoHandleResetParams(ctx *context.Context) {
ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat")) ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat"))
} }
// YAML
type YAMLController struct {
Controller
}
func (jc *YAMLController) Prepare() {
jc.Data["yaml"] = "prepare"
jc.ServeYAML()
}
func (jc *YAMLController) Get() {
jc.Data["Username"] = "astaxie"
jc.Ctx.Output.Body([]byte("ok"))
}
func TestYAMLPrepare(t *testing.T) {
r, _ := http.NewRequest("GET", "/yaml/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/yaml/list", &YAMLController{})
handler.ServeHTTP(w, r)
if strings.TrimSpace(w.Body.String()) != "prepare" {
t.Errorf(w.Body.String())
}
}

View File

@ -133,7 +133,7 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) {
// SessionExist check ledis session exist by sid // SessionExist check ledis session exist by sid
func (lp *Provider) SessionExist(sid string) bool { func (lp *Provider) SessionExist(sid string) bool {
count, _ := c.Exists([]byte(sid)) count, _ := c.Exists([]byte(sid))
return !(count == 0) return count != 0
} }
// SessionRegenerate generate new sid for ledis session // SessionRegenerate generate new sid for ledis session

View File

@ -128,9 +128,12 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) {
} }
} }
item, err := client.Get(sid) item, err := client.Get(sid)
if err != nil && err == memcache.ErrCacheMiss { if err != nil {
rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime} if err == memcache.ErrCacheMiss {
return rs, nil rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime}
return rs, nil
}
return nil, err
} }
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
if len(item.Value) == 0 { if len(item.Value) == 0 {

View File

@ -170,7 +170,7 @@ func (mp *Provider) SessionExist(sid string) bool {
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
return !(err == sql.ErrNoRows) return err != sql.ErrNoRows
} }
// SessionRegenerate generate new sid for mysql session // SessionRegenerate generate new sid for mysql session

View File

@ -184,7 +184,7 @@ func (mp *Provider) SessionExist(sid string) bool {
row := c.QueryRow("select session_data from session where session_key=$1", sid) row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
return !(err == sql.ErrNoRows) return err != sql.ErrNoRows
} }
// SessionRegenerate generate new sid for postgresql session // SessionRegenerate generate new sid for postgresql session

View File

@ -14,9 +14,9 @@
// Package redis for session provider // Package redis for session provider
// //
// depend on github.com/garyburd/redigo/redis // depend on github.com/gomodule/redigo/redis
// //
// go install github.com/garyburd/redigo/redis // go install github.com/gomodule/redigo/redis
// //
// Usage: // Usage:
// import( // import(
@ -24,10 +24,10 @@
// "github.com/astaxie/beego/session" // "github.com/astaxie/beego/session"
// ) // )
// //
// func init() { // func init() {
// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) // globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``)
// go globalSessions.GC() // go globalSessions.GC()
// } // }
// //
// more docs: http://beego.me/docs/module/session.md // more docs: http://beego.me/docs/module/session.md
package redis package redis
@ -37,10 +37,11 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
"github.com/garyburd/redigo/redis" "github.com/gomodule/redigo/redis"
) )
var redispder = &Provider{} var redispder = &Provider{}
@ -118,8 +119,8 @@ type Provider struct {
} }
// SessionInit init redis session // SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum // savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
// e.g. 127.0.0.1:6379,100,astaxie,0 // e.g. 127.0.0.1:6379,100,astaxie,0,30
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
@ -149,27 +150,39 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
} else { } else {
rp.dbNum = 0 rp.dbNum = 0
} }
rp.poollist = redis.NewPool(func() (redis.Conn, error) { var idleTimeout time.Duration = 0
c, err := redis.Dial("tcp", rp.savePath) if len(configs) > 4 {
if err != nil { timeout, err := strconv.Atoi(configs[4])
return nil, err if err == nil && timeout > 0 {
idleTimeout = time.Duration(timeout) * time.Second
} }
if rp.password != "" { }
if _, err = c.Do("AUTH", rp.password); err != nil { rp.poollist = &redis.Pool{
c.Close() Dial: func() (redis.Conn, error) {
return nil, err c, err := redis.Dial("tcp", rp.savePath)
}
}
//some redis proxy such as twemproxy is not support select command
if rp.dbNum > 0 {
_, err = c.Do("SELECT", rp.dbNum)
if err != nil { if err != nil {
c.Close()
return nil, err return nil, err
} }
} if rp.password != "" {
return c, err if _, err = c.Do("AUTH", rp.password); err != nil {
}, rp.poolsize) c.Close()
return nil, err
}
}
// some redis proxy such as twemproxy is not support select command
if rp.dbNum > 0 {
_, err = c.Do("SELECT", rp.dbNum)
if err != nil {
c.Close()
return nil, err
}
}
return c, err
},
MaxIdle: rp.poolsize,
}
rp.poollist.IdleTimeout = idleTimeout
return rp.poollist.Get().Err() return rp.poollist.Get().Err()
} }

View File

@ -0,0 +1,220 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/go-redis/redis
//
// go install github.com/go-redis/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis_cluster"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package redis_cluster
import (
"net/http"
"strconv"
"strings"
"sync"
"github.com/astaxie/beego/session"
rediss "github.com/go-redis/redis"
"time"
)
var redispder = &Provider{}
// MaxPoolSize redis_cluster max pool size
var MaxPoolSize = 1000
// SessionStore redis_cluster session store
type SessionStore struct {
p *rediss.ClusterClient
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis_cluster session
func (rs *SessionStore) Set(key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis_cluster session
func (rs *SessionStore) Get(key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis_cluster session
func (rs *SessionStore) Delete(key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis_cluster session
func (rs *SessionStore) Flush() error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis_cluster session id
func (rs *SessionStore) SessionID() string {
return rs.sid
}
// SessionRelease save session values to redis_cluster
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second)
}
// Provider redis_cluster session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
poollist *rediss.ClusterClient
}
// SessionInit init redis_cluster session
// savepath like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
Addrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_cluster session by sid
func (rp *Provider) SessionRead(sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != rediss.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis_cluster session exist by sid
func (rp *Provider) SessionExist(sid string) bool {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false
}
return true
}
// SessionRegenerate generate new sid for redis_cluster session
func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second)
}
return rp.SessionRead(sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC() {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll() int {
return 0
}
func init() {
session.Register("redis_cluster", redispder)
}

View File

@ -0,0 +1,234 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/go-redis/redis
//
// go install github.com/go-redis/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis_sentinel"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``)
// go globalSessions.GC()
// }
//
// more detail about params: please check the notes on the function SessionInit in this package
package redis_sentinel
import (
"github.com/astaxie/beego/session"
"github.com/go-redis/redis"
"net/http"
"strconv"
"strings"
"sync"
"time"
)
var redispder = &Provider{}
// DefaultPoolSize redis_sentinel default pool size
var DefaultPoolSize = 100
// SessionStore redis_sentinel session store
type SessionStore struct {
p *redis.Client
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis_sentinel session
func (rs *SessionStore) Set(key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis_sentinel session
func (rs *SessionStore) Get(key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis_sentinel session
func (rs *SessionStore) Delete(key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis_sentinel session
func (rs *SessionStore) Flush() error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis_sentinel session id
func (rs *SessionStore) SessionID() string {
return rs.sid
}
// SessionRelease save session values to redis_sentinel
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis_sentinel session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
poollist *redis.Client
masterName string
}
// SessionInit init redis_sentinel session
// savepath like redis sentinel addr,pool size,password,dbnum,masterName
// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = DefaultPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = DefaultPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
if configs[4] != "" {
rp.masterName = configs[4]
} else {
rp.masterName = "mymaster"
}
} else {
rp.masterName = "mymaster"
}
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
SentinelAddrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
MasterName: rp.masterName,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_sentinel session by sid
func (rp *Provider) SessionRead(sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != redis.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis_sentinel session exist by sid
func (rp *Provider) SessionExist(sid string) bool {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false
}
return true
}
// SessionRegenerate generate new sid for redis_sentinel session
func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC() {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll() int {
return 0
}
func init() {
session.Register("redis_sentinel", redispder)
}

View File

@ -0,0 +1,90 @@
package redis_sentinel
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/astaxie/beego/session"
)
func TestRedisSentinel(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
ProviderConfig: "127.0.0.1:6379,100,,0,master",
}
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
if e != nil {
t.Log(e)
return
}
//todo test if e==nil
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("session start failed:", err)
}
defer sess.SessionRelease(w)
// SET AND GET
err = sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set username failed:", err)
}
username := sess.Get("username")
if username != "astaxie" {
t.Fatal("get username failed")
}
// DELETE
err = sess.Delete("username")
if err != nil {
t.Fatal("delete username failed:", err)
}
username = sess.Get("username")
if username != nil {
t.Fatal("delete username failed")
}
// FLUSH
err = sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set failed:", err)
}
err = sess.Set("password", "1qaz2wsx")
if err != nil {
t.Fatal("set failed:", err)
}
username = sess.Get("username")
if username != "astaxie" {
t.Fatal("get username failed")
}
password := sess.Get("password")
if password != "1qaz2wsx" {
t.Fatal("get password failed")
}
err = sess.Flush()
if err != nil {
t.Fatal("flush failed:", err)
}
username = sess.Get("username")
if username != nil {
t.Fatal("flush failed")
}
password = sess.Get("password")
if password != nil {
t.Fatal("flush failed")
}
sess.SessionRelease(w)
}

View File

@ -21,6 +21,7 @@ import (
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"time" "time"
) )
@ -78,6 +79,8 @@ func (fs *FileSessionStore) SessionID() string {
// SessionRelease Write file session to local file with Gob string // SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
b, err := EncodeGob(fs.values) b, err := EncodeGob(fs.values)
if err != nil { if err != nil {
SLogger.Println(err) SLogger.Println(err)
@ -125,6 +128,9 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
// if file is not exist, create it. // if file is not exist, create it.
// the file path is generated from sid string. // the file path is generated from sid string.
func (fp *FileProvider) SessionRead(sid string) (Store, error) { func (fp *FileProvider) SessionRead(sid string) (Store, error) {
if strings.ContainsAny(sid, "./") {
return nil, nil
}
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()
@ -164,7 +170,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
} }
// SessionExist Check file session exist. // SessionExist Check file session exist.
// it checkes the file named from sid exist or not. // it checks the file named from sid exist or not.
func (fp *FileProvider) SessionExist(sid string) bool { func (fp *FileProvider) SessionExist(sid string) bool {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()

View File

@ -81,6 +81,15 @@ func Register(name string, provide Provider) {
provides[name] = provide provides[name] = provide
} }
//GetProvider
func GetProvider(name string) (Provider, error) {
provider, ok := provides[name]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name)
}
return provider, nil
}
// ManagerConfig define the session config // ManagerConfig define the session config
type ManagerConfig struct { type ManagerConfig struct {
CookieName string `json:"cookieName"` CookieName string `json:"cookieName"`
@ -96,6 +105,7 @@ type ManagerConfig struct {
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
SessionIDPrefix string `json:"sessionIDPrefix"`
} }
// Manager contains Provider and its configuration. // Manager contains Provider and its configuration.
@ -153,6 +163,11 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) {
}, nil }, nil
} }
// GetProvider return current manager's provider
func (manager *Manager) GetProvider() Provider {
return manager.provider
}
// getSid retrieves session identifier from HTTP Request. // getSid retrieves session identifier from HTTP Request.
// First try to retrieve id by reading from cookie, session cookie name is configurable, // First try to retrieve id by reading from cookie, session cookie name is configurable,
// if not exist, then retrieve id from querying parameters. // if not exist, then retrieve id from querying parameters.
@ -331,7 +346,7 @@ func (manager *Manager) sessionID() (string, error) {
if n != len(b) || err != nil { if n != len(b) || err != nil {
return "", fmt.Errorf("Could not successfully read from the system CSPRNG") return "", fmt.Errorf("Could not successfully read from the system CSPRNG")
} }
return hex.EncodeToString(b), nil return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil
} }
// Set cookie with https. // Set cookie with https.

View File

@ -74,7 +74,7 @@ func serverStaticRouter(ctx *context.Context) {
if enableCompress { if enableCompress {
acceptEncoding = context.ParseEncoding(ctx.Request) acceptEncoding = context.ParseEncoding(ctx.Request)
} }
b, n, sch, err := openFile(filePath, fileInfo, acceptEncoding) b, n, sch, reader, err := openFile(filePath, fileInfo, acceptEncoding)
if err != nil { if err != nil {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
logs.Warn("Can't compress the file:", filePath, err) logs.Warn("Can't compress the file:", filePath, err)
@ -89,47 +89,53 @@ func serverStaticRouter(ctx *context.Context) {
ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10)) ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10))
} }
http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, sch) http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, reader)
} }
type serveContentHolder struct { type serveContentHolder struct {
*bytes.Reader data []byte
modTime time.Time modTime time.Time
size int64 size int64
encoding string encoding string
} }
type serveContentReader struct {
*bytes.Reader
}
var ( var (
staticFileMap = make(map[string]*serveContentHolder) staticFileMap = make(map[string]*serveContentHolder)
mapLock sync.RWMutex mapLock sync.RWMutex
) )
func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, error) { func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) {
mapKey := acceptEncoding + ":" + filePath mapKey := acceptEncoding + ":" + filePath
mapLock.RLock() mapLock.RLock()
mapFile := staticFileMap[mapKey] mapFile := staticFileMap[mapKey]
mapLock.RUnlock() mapLock.RUnlock()
if isOk(mapFile, fi) { if isOk(mapFile, fi) {
return mapFile.encoding != "", mapFile.encoding, mapFile, nil reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)}
return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil
} }
mapLock.Lock() mapLock.Lock()
defer mapLock.Unlock() defer mapLock.Unlock()
if mapFile = staticFileMap[mapKey]; !isOk(mapFile, fi) { if mapFile = staticFileMap[mapKey]; !isOk(mapFile, fi) {
file, err := os.Open(filePath) file, err := os.Open(filePath)
if err != nil { if err != nil {
return false, "", nil, err return false, "", nil, nil, err
} }
defer file.Close() defer file.Close()
var bufferWriter bytes.Buffer var bufferWriter bytes.Buffer
_, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file) _, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file)
if err != nil { if err != nil {
return false, "", nil, err return false, "", nil, nil, err
} }
mapFile = &serveContentHolder{Reader: bytes.NewReader(bufferWriter.Bytes()), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n} mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n}
staticFileMap[mapKey] = mapFile staticFileMap[mapKey] = mapFile
} }
return mapFile.encoding != "", mapFile.encoding, mapFile, nil reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)}
return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil
} }
func isOk(s *serveContentHolder, fi os.FileInfo) bool { func isOk(s *serveContentHolder, fi os.FileInfo) bool {
@ -172,7 +178,7 @@ func searchFile(ctx *context.Context) (string, os.FileInfo, error) {
if !strings.Contains(requestPath, prefix) { if !strings.Contains(requestPath, prefix) {
continue continue
} }
if len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' {
continue continue
} }
filePath := path.Join(staticDir, requestPath[len(prefix):]) filePath := path.Join(staticDir, requestPath[len(prefix):])

View File

@ -16,7 +16,7 @@ var licenseFile = filepath.Join(currentWorkDir, "LICENSE")
func testOpenFile(encoding string, content []byte, t *testing.T) { func testOpenFile(encoding string, content []byte, t *testing.T) {
fi, _ := os.Stat(licenseFile) fi, _ := os.Stat(licenseFile)
b, n, sch, err := openFile(licenseFile, fi, encoding) b, n, sch, reader, err := openFile(licenseFile, fi, encoding)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
t.Fail() t.Fail()
@ -24,7 +24,7 @@ func testOpenFile(encoding string, content []byte, t *testing.T) {
t.Log("open static file encoding "+n, b) t.Log("open static file encoding "+n, b)
assetOpenFileAndContent(sch, content, t) assetOpenFileAndContent(sch, reader, content, t)
} }
func TestOpenStaticFile_1(t *testing.T) { func TestOpenStaticFile_1(t *testing.T) {
file, _ := os.Open(licenseFile) file, _ := os.Open(licenseFile)
@ -53,13 +53,13 @@ func TestOpenStaticFileDeflate_1(t *testing.T) {
testOpenFile("deflate", content, t) testOpenFile("deflate", content, t)
} }
func assetOpenFileAndContent(sch *serveContentHolder, content []byte, t *testing.T) { func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) {
t.Log(sch.size, len(content)) t.Log(sch.size, len(content))
if sch.size != int64(len(content)) { if sch.size != int64(len(content)) {
t.Log("static content file size not same") t.Log("static content file size not same")
t.Fail() t.Fail()
} }
bs, _ := ioutil.ReadAll(sch) bs, _ := ioutil.ReadAll(reader)
for i, v := range content { for i, v := range content {
if v != bs[i] { if v != bs[i] {
t.Log("content not same") t.Log("content not same")

View File

@ -121,6 +121,8 @@ type Schema struct {
Type string `json:"type,omitempty" yaml:"type,omitempty"` Type string `json:"type,omitempty" yaml:"type,omitempty"`
Items *Schema `json:"items,omitempty" yaml:"items,omitempty"` Items *Schema `json:"items,omitempty" yaml:"items,omitempty"`
Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"`
Enum []interface{} `json:"enum,omitempty" yaml:"enum,omitempty"`
Example interface{} `json:"example,omitempty" yaml:"example,omitempty"`
} }
// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification // Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification
@ -130,7 +132,7 @@ type Propertie struct {
Description string `json:"description,omitempty" yaml:"description,omitempty"` Description string `json:"description,omitempty" yaml:"description,omitempty"`
Default interface{} `json:"default,omitempty" yaml:"default,omitempty"` Default interface{} `json:"default,omitempty" yaml:"default,omitempty"`
Type string `json:"type,omitempty" yaml:"type,omitempty"` Type string `json:"type,omitempty" yaml:"type,omitempty"`
Example string `json:"example,omitempty" yaml:"example,omitempty"` Example interface{} `json:"example,omitempty" yaml:"example,omitempty"`
Required []string `json:"required,omitempty" yaml:"required,omitempty"` Required []string `json:"required,omitempty" yaml:"required,omitempty"`
Format string `json:"format,omitempty" yaml:"format,omitempty"` Format string `json:"format,omitempty" yaml:"format,omitempty"`
ReadOnly bool `json:"readOnly,omitempty" yaml:"readOnly,omitempty"` ReadOnly bool `json:"readOnly,omitempty" yaml:"readOnly,omitempty"`
@ -141,7 +143,7 @@ type Propertie struct {
// Response as they are returned from executing this operation. // Response as they are returned from executing this operation.
type Response struct { type Response struct {
Description string `json:"description,omitempty" yaml:"description,omitempty"` Description string `json:"description" yaml:"description"`
Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"`
Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"`
} }

View File

@ -20,6 +20,7 @@ import (
"html/template" "html/template"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -40,6 +41,7 @@ var (
beeTemplateExt = []string{"tpl", "html"} beeTemplateExt = []string{"tpl", "html"}
// beeTemplatePreprocessors stores associations of extension -> preprocessor handler // beeTemplatePreprocessors stores associations of extension -> preprocessor handler
beeTemplateEngines = map[string]templatePreProcessor{} beeTemplateEngines = map[string]templatePreProcessor{}
beeTemplateFS = defaultFSFunc
) )
// ExecuteTemplate applies the template with name to the specified data object, // ExecuteTemplate applies the template with name to the specified data object,
@ -181,12 +183,17 @@ func lockViewPaths() {
// BuildTemplate will build all template files in a directory. // BuildTemplate will build all template files in a directory.
// it makes beego can render any template file in view directory. // it makes beego can render any template file in view directory.
func BuildTemplate(dir string, files ...string) error { func BuildTemplate(dir string, files ...string) error {
if _, err := os.Stat(dir); err != nil { var err error
fs := beeTemplateFS()
f, err := fs.Open(dir)
if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil return nil
} }
return errors.New("dir open err") return errors.New("dir open err")
} }
defer f.Close()
beeTemplates, ok := beeViewPathTemplates[dir] beeTemplates, ok := beeViewPathTemplates[dir]
if !ok { if !ok {
panic("Unknown view path: " + dir) panic("Unknown view path: " + dir)
@ -195,11 +202,11 @@ func BuildTemplate(dir string, files ...string) error {
root: dir, root: dir,
files: make(map[string][]string), files: make(map[string][]string),
} }
err := filepath.Walk(dir, func(path string, f os.FileInfo, err error) error { err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error {
return self.visit(path, f, err) return self.visit(path, f, err)
}) })
if err != nil { if err != nil {
fmt.Printf("filepath.Walk() returned %v\n", err) fmt.Printf("Walk() returned %v\n", err)
return err return err
} }
buildAllFiles := len(files) == 0 buildAllFiles := len(files) == 0
@ -210,18 +217,18 @@ func BuildTemplate(dir string, files ...string) error {
ext := filepath.Ext(file) ext := filepath.Ext(file)
var t *template.Template var t *template.Template
if len(ext) == 0 { if len(ext) == 0 {
t, err = getTemplate(self.root, file, v...) t, err = getTemplate(self.root, fs, file, v...)
} else if fn, ok := beeTemplateEngines[ext[1:]]; ok { } else if fn, ok := beeTemplateEngines[ext[1:]]; ok {
t, err = fn(self.root, file, beegoTplFuncMap) t, err = fn(self.root, file, beegoTplFuncMap)
} else { } else {
t, err = getTemplate(self.root, file, v...) t, err = getTemplate(self.root, fs, file, v...)
} }
if err != nil { if err != nil {
logs.Error("parse template err:", file, err) logs.Error("parse template err:", file, err)
templatesLock.Unlock()
return err return err
} else {
beeTemplates[file] = t
} }
beeTemplates[file] = t
templatesLock.Unlock() templatesLock.Unlock()
} }
} }
@ -229,20 +236,23 @@ func BuildTemplate(dir string, files ...string) error {
return nil return nil
} }
func getTplDeep(root, file, parent string, t *template.Template) (*template.Template, [][]string, error) { func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *template.Template) (*template.Template, [][]string, error) {
var fileAbsPath string var fileAbsPath string
var rParent string var rParent string
if filepath.HasPrefix(file, "../") { var err error
if strings.HasPrefix(file, "../") {
rParent = filepath.Join(filepath.Dir(parent), file) rParent = filepath.Join(filepath.Dir(parent), file)
fileAbsPath = filepath.Join(root, filepath.Dir(parent), file) fileAbsPath = filepath.Join(root, filepath.Dir(parent), file)
} else { } else {
rParent = file rParent = file
fileAbsPath = filepath.Join(root, file) fileAbsPath = filepath.Join(root, file)
} }
if e := utils.FileExists(fileAbsPath); !e { f, err := fs.Open(fileAbsPath)
if err != nil {
panic("can't find template file:" + file) panic("can't find template file:" + file)
} }
data, err := ioutil.ReadFile(fileAbsPath) defer f.Close()
data, err := ioutil.ReadAll(f)
if err != nil { if err != nil {
return nil, [][]string{}, err return nil, [][]string{}, err
} }
@ -261,7 +271,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
if !HasTemplateExt(m[1]) { if !HasTemplateExt(m[1]) {
continue continue
} }
_, _, err = getTplDeep(root, m[1], rParent, t) _, _, err = getTplDeep(root, fs, m[1], rParent, t)
if err != nil { if err != nil {
return nil, [][]string{}, err return nil, [][]string{}, err
} }
@ -270,14 +280,14 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
return t, allSub, nil return t, allSub, nil
} }
func getTemplate(root, file string, others ...string) (t *template.Template, err error) { func getTemplate(root string, fs http.FileSystem, file string, others ...string) (t *template.Template, err error) {
t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap) t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap)
var subMods [][]string var subMods [][]string
t, subMods, err = getTplDeep(root, file, "", t) t, subMods, err = getTplDeep(root, fs, file, "", t)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t, err = _getTemplate(t, root, subMods, others...) t, err = _getTemplate(t, root, fs, subMods, others...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -285,7 +295,7 @@ func getTemplate(root, file string, others ...string) (t *template.Template, err
return return
} }
func _getTemplate(t0 *template.Template, root string, subMods [][]string, others ...string) (t *template.Template, err error) { func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMods [][]string, others ...string) (t *template.Template, err error) {
t = t0 t = t0
for _, m := range subMods { for _, m := range subMods {
if len(m) == 2 { if len(m) == 2 {
@ -297,11 +307,11 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
for _, otherFile := range others { for _, otherFile := range others {
if otherFile == m[1] { if otherFile == m[1] {
var subMods1 [][]string var subMods1 [][]string
t, subMods1, err = getTplDeep(root, otherFile, "", t) t, subMods1, err = getTplDeep(root, fs, otherFile, "", t)
if err != nil { if err != nil {
logs.Trace("template parse file err:", err) logs.Trace("template parse file err:", err)
} else if len(subMods1) > 0 { } else if len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...) t, err = _getTemplate(t, root, fs, subMods1, others...)
} }
break break
} }
@ -310,8 +320,16 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
for _, otherFile := range others { for _, otherFile := range others {
var data []byte var data []byte
fileAbsPath := filepath.Join(root, otherFile) fileAbsPath := filepath.Join(root, otherFile)
data, err = ioutil.ReadFile(fileAbsPath) f, err := fs.Open(fileAbsPath)
if err != nil { if err != nil {
f.Close()
logs.Trace("template file parse error, not success open file:", err)
continue
}
data, err = ioutil.ReadAll(f)
f.Close()
if err != nil {
logs.Trace("template file parse error, not success read file:", err)
continue continue
} }
reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"") reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"")
@ -319,11 +337,14 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
for _, sub := range allSub { for _, sub := range allSub {
if len(sub) == 2 && sub[1] == m[1] { if len(sub) == 2 && sub[1] == m[1] {
var subMods1 [][]string var subMods1 [][]string
t, subMods1, err = getTplDeep(root, otherFile, "", t) t, subMods1, err = getTplDeep(root, fs, otherFile, "", t)
if err != nil { if err != nil {
logs.Trace("template parse file err:", err) logs.Trace("template parse file err:", err)
} else if len(subMods1) > 0 { } else if len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...) t, err = _getTemplate(t, root, fs, subMods1, others...)
if err != nil {
logs.Trace("template parse file err:", err)
}
} }
break break
} }
@ -335,6 +356,17 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
return return
} }
type templateFSFunc func() http.FileSystem
func defaultFSFunc() http.FileSystem {
return FileSystem{}
}
// SetTemplateFSFunc set default filesystem function
func SetTemplateFSFunc(fnt templateFSFunc) {
beeTemplateFS = fnt
}
// SetViewsPath sets view directory path in beego application. // SetViewsPath sets view directory path in beego application.
func SetViewsPath(path string) *App { func SetViewsPath(path string) *App {
BConfig.WebConfig.ViewsPath = path BConfig.WebConfig.ViewsPath = path

View File

@ -16,6 +16,9 @@ package beego
import ( import (
"bytes" "bytes"
"github.com/astaxie/beego/testdata"
"github.com/elazarl/go-bindata-assetfs"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -256,3 +259,58 @@ func TestTemplateLayout(t *testing.T) {
} }
os.RemoveAll(dir) os.RemoveAll(dir)
} }
type TestingFileSystem struct {
assetfs *assetfs.AssetFS
}
func (d TestingFileSystem) Open(name string) (http.File, error) {
return d.assetfs.Open(name)
}
var outputBinData = `<!DOCTYPE html>
<html>
<head>
<title>beego welcome template</title>
</head>
<body>
<h1>Hello, blocks!</h1>
<h1>Hello, astaxie!</h1>
<h2>Hello</h2>
<p> This is SomeVar: val</p>
</body>
</html>
`
func TestFsBinData(t *testing.T) {
SetTemplateFSFunc(func() http.FileSystem {
return TestingFileSystem{&assetfs.AssetFS{Asset: testdata.Asset, AssetDir: testdata.AssetDir, AssetInfo: testdata.AssetInfo}}
})
dir := "views"
if err := AddViewPath("views"); err != nil {
t.Fatal(err)
}
beeTemplates := beeViewPathTemplates[dir]
if len(beeTemplates) != 3 {
t.Fatalf("should be 3 but got %v", len(beeTemplates))
}
if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil {
t.Fatal(err)
}
out := bytes.NewBufferString("")
if err := beeTemplates["index.tpl"].ExecuteTemplate(out, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil {
t.Fatal(err)
}
if out.String() != outputBinData {
t.Log(out.String())
t.Fatal("Compare failed")
}
}

View File

@ -17,6 +17,7 @@ package beego
import ( import (
"errors" "errors"
"fmt" "fmt"
"html"
"html/template" "html/template"
"net/url" "net/url"
"reflect" "reflect"
@ -54,21 +55,21 @@ func Substr(s string, start, length int) string {
// HTML2str returns escaping text convert from html. // HTML2str returns escaping text convert from html.
func HTML2str(html string) string { func HTML2str(html string) string {
re, _ := regexp.Compile(`\<[\S\s]+?\>`) re := regexp.MustCompile(`\<[\S\s]+?\>`)
html = re.ReplaceAllStringFunc(html, strings.ToLower) html = re.ReplaceAllStringFunc(html, strings.ToLower)
//remove STYLE //remove STYLE
re, _ = regexp.Compile(`\<style[\S\s]+?\</style\>`) re = regexp.MustCompile(`\<style[\S\s]+?\</style\>`)
html = re.ReplaceAllString(html, "") html = re.ReplaceAllString(html, "")
//remove SCRIPT //remove SCRIPT
re, _ = regexp.Compile(`\<script[\S\s]+?\</script\>`) re = regexp.MustCompile(`\<script[\S\s]+?\</script\>`)
html = re.ReplaceAllString(html, "") html = re.ReplaceAllString(html, "")
re, _ = regexp.Compile(`\<[\S\s]+?\>`) re = regexp.MustCompile(`\<[\S\s]+?\>`)
html = re.ReplaceAllString(html, "\n") html = re.ReplaceAllString(html, "\n")
re, _ = regexp.Compile(`\s{2,}`) re = regexp.MustCompile(`\s{2,}`)
html = re.ReplaceAllString(html, "\n") html = re.ReplaceAllString(html, "\n")
return strings.TrimSpace(html) return strings.TrimSpace(html)
@ -171,7 +172,7 @@ func GetConfig(returnType, key string, defaultVal interface{}) (value interface{
case "DIY": case "DIY":
value, err = AppConfig.DIY(key) value, err = AppConfig.DIY(key)
default: default:
err = errors.New("Config keys must be of type String, Bool, Int, Int64, Float, or DIY") err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY")
} }
if err != nil { if err != nil {
@ -207,14 +208,12 @@ func Htmlquote(text string) string {
'&lt;&#39;&amp;&quot;&gt;' '&lt;&#39;&amp;&quot;&gt;'
*/ */
text = strings.Replace(text, "&", "&amp;", -1) // Must be done first! text = html.EscapeString(text)
text = strings.Replace(text, "<", "&lt;", -1) text = strings.NewReplacer(
text = strings.Replace(text, ">", "&gt;", -1) ``, "&ldquo;",
text = strings.Replace(text, "'", "&#39;", -1) ``, "&rdquo;",
text = strings.Replace(text, "\"", "&quot;", -1) ` `, "&nbsp;",
text = strings.Replace(text, "“", "&ldquo;", -1) ).Replace(text)
text = strings.Replace(text, "”", "&rdquo;", -1)
text = strings.Replace(text, " ", "&nbsp;", -1)
return strings.TrimSpace(text) return strings.TrimSpace(text)
} }
@ -228,17 +227,7 @@ func Htmlunquote(text string) string {
'<\\'&">' '<\\'&">'
*/ */
// strings.Replace(s, old, new, n) text = html.UnescapeString(text)
// 在s字符串中把old字符串替换为new字符串n表示替换的次数小于0表示全部替换
text = strings.Replace(text, "&nbsp;", " ", -1)
text = strings.Replace(text, "&rdquo;", "”", -1)
text = strings.Replace(text, "&ldquo;", "“", -1)
text = strings.Replace(text, "&quot;", "\"", -1)
text = strings.Replace(text, "&#39;", "'", -1)
text = strings.Replace(text, "&gt;", ">", -1)
text = strings.Replace(text, "&lt;", "<", -1)
text = strings.Replace(text, "&amp;", "&", -1) // Must be done last!
return strings.TrimSpace(text) return strings.TrimSpace(text)
} }
@ -308,10 +297,17 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e
tag = tags[0] tag = tags[0]
} }
value := form.Get(tag) formValues := form[tag]
if len(value) == 0 { var value string
if len(formValues) == 0 {
continue continue
} }
if len(formValues) == 1 {
value = formValues[0]
if value == "" {
continue
}
}
switch fieldT.Type.Kind() { switch fieldT.Type.Kind() {
case reflect.Bool: case reflect.Bool:
@ -703,7 +699,7 @@ func ge(arg1, arg2 interface{}) (bool, error) {
// MapGet getting value from map by keys // MapGet getting value from map by keys
// usage: // usage:
// Data["m"] = map[string]interface{} { // Data["m"] = M{
// "a": 1, // "a": 1,
// "1": map[string]float64{ // "1": map[string]float64{
// "c": 4, // "c": 4,

View File

@ -94,7 +94,7 @@ func TestCompareRelated(t *testing.T) {
} }
func TestHtmlquote(t *testing.T) { func TestHtmlquote(t *testing.T) {
h := `&lt;&#39;&nbsp;&rdquo;&ldquo;&amp;&quot;&gt;` h := `&lt;&#39;&nbsp;&rdquo;&ldquo;&amp;&#34;&gt;`
s := `<' ”“&">` s := `<' ”“&">`
if Htmlquote(s) != h { if Htmlquote(s) != h {
t.Error("should be equal") t.Error("should be equal")
@ -102,8 +102,8 @@ func TestHtmlquote(t *testing.T) {
} }
func TestHtmlunquote(t *testing.T) { func TestHtmlunquote(t *testing.T) {
h := `&lt;&#39;&nbsp;&rdquo;&ldquo;&amp;&quot;&gt;` h := `&lt;&#39;&nbsp;&rdquo;&ldquo;&amp;&#34;&gt;`
s := `<' ”“&">` s := `<' ”“&">`
if Htmlunquote(h) != s { if Htmlunquote(h) != s {
t.Error("should be equal") t.Error("should be equal")
} }
@ -111,7 +111,7 @@ func TestHtmlunquote(t *testing.T) {
func TestParseForm(t *testing.T) { func TestParseForm(t *testing.T) {
type ExtendInfo struct { type ExtendInfo struct {
Hobby string `form:"hobby"` Hobby []string `form:"hobby"`
Memo string Memo string
} }
@ -146,7 +146,7 @@ func TestParseForm(t *testing.T) {
"date": []string{"2014-11-12"}, "date": []string{"2014-11-12"},
"organization": []string{"beego"}, "organization": []string{"beego"},
"title": []string{"CXO"}, "title": []string{"CXO"},
"hobby": []string{"Basketball"}, "hobby": []string{"", "Basketball", "Football"},
"memo": []string{"nothing"}, "memo": []string{"nothing"},
} }
if err := ParseForm(form, u); err == nil { if err := ParseForm(form, u); err == nil {
@ -186,8 +186,14 @@ func TestParseForm(t *testing.T) {
if u.Title != "CXO" { if u.Title != "CXO" {
t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) t.Errorf("Title should equal `CXO`, but got `%v`", u.Title)
} }
if u.Hobby != "Basketball" { if u.Hobby[0] != "" {
t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby) t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0])
}
if u.Hobby[1] != "Basketball" {
t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1])
}
if u.Hobby[2] != "Football" {
t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2])
} }
if len(u.Memo) != 0 { if len(u.Memo) != 0 {
t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo))
@ -197,7 +203,6 @@ func TestParseForm(t *testing.T) {
func TestRenderForm(t *testing.T) { func TestRenderForm(t *testing.T) {
type user struct { type user struct {
ID int `form:"-"` ID int `form:"-"`
tag string `form:"tag"`
Name interface{} `form:"username"` Name interface{} `form:"username"`
Age int `form:"age,text,年龄:"` Age int `form:"age,text,年龄:"`
Sex string Sex string
@ -329,7 +334,7 @@ func TestMapGet(t *testing.T) {
} }
// test 2 level map // test 2 level map
m2 := map[string]interface{}{ m2 := M{
"1": map[string]float64{ "1": map[string]float64{
"2": 3.5, "2": 3.5,
}, },
@ -344,11 +349,11 @@ func TestMapGet(t *testing.T) {
} }
// test 5 level map // test 5 level map
m5 := map[string]interface{}{ m5 := M{
"1": map[string]interface{}{ "1": M{
"2": map[string]interface{}{ "2": M{
"3": map[string]interface{}{ "3": M{
"4": map[string]interface{}{ "4": M{
"5": 1.2, "5": 1.2,
}, },
}, },

2
testdata/Makefile vendored Normal file
View File

@ -0,0 +1,2 @@
build_view:
$(GOPATH)/bin/go-bindata-assetfs -pkg testdata views/...

296
testdata/bindata.go vendored Normal file
View File

@ -0,0 +1,296 @@
// Code generated by go-bindata.
// sources:
// views/blocks/block.tpl
// views/header.tpl
// views/index.tpl
// DO NOT EDIT!
package testdata
import (
"bytes"
"compress/gzip"
"fmt"
"github.com/elazarl/go-bindata-assetfs"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"time"
)
func bindataRead(data []byte, name string) ([]byte, error) {
gz, err := gzip.NewReader(bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("Read %q: %v", name, err)
}
var buf bytes.Buffer
_, err = io.Copy(&buf, gz)
clErr := gz.Close()
if err != nil {
return nil, fmt.Errorf("Read %q: %v", name, err)
}
if clErr != nil {
return nil, err
}
return buf.Bytes(), nil
}
type asset struct {
bytes []byte
info os.FileInfo
}
type bindataFileInfo struct {
name string
size int64
mode os.FileMode
modTime time.Time
}
func (fi bindataFileInfo) Name() string {
return fi.name
}
func (fi bindataFileInfo) Size() int64 {
return fi.size
}
func (fi bindataFileInfo) Mode() os.FileMode {
return fi.mode
}
func (fi bindataFileInfo) ModTime() time.Time {
return fi.modTime
}
func (fi bindataFileInfo) IsDir() bool {
return false
}
func (fi bindataFileInfo) Sys() interface{} {
return nil
}
var _viewsBlocksBlockTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\x4a\xca\xc9\x4f\xce\x56\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x00\x8b\x15\x2b\xda\xe8\x67\x18\xda\x71\x55\x57\xa7\xe6\xa5\xd4\xd6\x02\x02\x00\x00\xff\xff\xfd\xa1\x7a\xf6\x32\x00\x00\x00")
func viewsBlocksBlockTplBytes() ([]byte, error) {
return bindataRead(
_viewsBlocksBlockTpl,
"views/blocks/block.tpl",
)
}
func viewsBlocksBlockTpl() (*asset, error) {
bytes, err := viewsBlocksBlockTplBytes()
if err != nil {
return nil, err
}
info := bindataFileInfo{name: "views/blocks/block.tpl", size: 50, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
var _viewsHeaderTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xaa\xae\x4e\x49\x4d\xcb\xcc\x4b\x55\x50\xca\x48\x4d\x4c\x49\x2d\x52\xaa\xad\xe5\xb2\xc9\x30\xb4\xf3\x48\xcd\xc9\xc9\xd7\x51\x48\x2c\x2e\x49\xac\xc8\x4c\x55\xb4\xd1\xcf\x30\xb4\xe3\xaa\xae\x4e\xcd\x4b\xa9\xad\x05\x04\x00\x00\xff\xff\xe4\x12\x47\x01\x34\x00\x00\x00")
func viewsHeaderTplBytes() ([]byte, error) {
return bindataRead(
_viewsHeaderTpl,
"views/header.tpl",
)
}
func viewsHeaderTpl() (*asset, error) {
bytes, err := viewsHeaderTplBytes()
if err != nil {
return nil, err
}
info := bindataFileInfo{name: "views/header.tpl", size: 52, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
var _viewsIndexTpl = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x64\x8f\xbd\x8a\xc3\x30\x10\x84\x6b\xeb\x29\xe6\xfc\x00\x16\xb8\x3c\x16\x35\x77\xa9\x13\x88\x09\xa4\xf4\xcf\x12\x99\x48\x48\xd8\x82\x10\x84\xde\x3d\xc8\x8a\x8b\x90\x6a\xa4\xd9\x6f\xd8\x59\xfa\xf9\x3f\xfe\x75\xd7\xd3\x01\x3a\x58\xa3\x04\x15\x01\x48\x73\x3f\xe5\x07\x40\x61\x0e\x86\xd5\xc0\x7c\x73\x78\xb0\x19\x9d\x65\x04\xb6\xde\xf4\x81\x49\x96\x69\x8e\xc8\x3d\x43\x83\x9b\x9e\x4a\x88\x2a\xc6\x9d\x43\x3d\x18\x37\xde\xeb\x94\x3e\xdd\x1c\xe1\xe5\xcb\xde\xe0\x55\x6e\xd2\x04\x6f\x32\x20\x2a\xd2\xad\x8a\x11\x4d\x97\x57\x22\x25\x92\xba\x55\xa2\x22\xaf\xd0\xe9\x79\xc5\xbc\xe2\xec\x2c\x5f\xfa\xe5\x17\x99\x7b\x7f\x36\xd2\x97\x8a\xa5\x19\xc9\x72\xe7\x2b\x00\x00\xff\xff\xb2\x39\xca\x9f\xff\x00\x00\x00")
func viewsIndexTplBytes() ([]byte, error) {
return bindataRead(
_viewsIndexTpl,
"views/index.tpl",
)
}
func viewsIndexTpl() (*asset, error) {
bytes, err := viewsIndexTplBytes()
if err != nil {
return nil, err
}
info := bindataFileInfo{name: "views/index.tpl", size: 255, mode: os.FileMode(436), modTime: time.Unix(1541434906, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
// Asset loads and returns the asset for the given name.
// It returns an error if the asset could not be found or
// could not be loaded.
func Asset(name string) ([]byte, error) {
cannonicalName := strings.Replace(name, "\\", "/", -1)
if f, ok := _bindata[cannonicalName]; ok {
a, err := f()
if err != nil {
return nil, fmt.Errorf("Asset %s can't read by error: %v", name, err)
}
return a.bytes, nil
}
return nil, fmt.Errorf("Asset %s not found", name)
}
// MustAsset is like Asset but panics when Asset would return an error.
// It simplifies safe initialization of global variables.
func MustAsset(name string) []byte {
a, err := Asset(name)
if err != nil {
panic("asset: Asset(" + name + "): " + err.Error())
}
return a
}
// AssetInfo loads and returns the asset info for the given name.
// It returns an error if the asset could not be found or
// could not be loaded.
func AssetInfo(name string) (os.FileInfo, error) {
cannonicalName := strings.Replace(name, "\\", "/", -1)
if f, ok := _bindata[cannonicalName]; ok {
a, err := f()
if err != nil {
return nil, fmt.Errorf("AssetInfo %s can't read by error: %v", name, err)
}
return a.info, nil
}
return nil, fmt.Errorf("AssetInfo %s not found", name)
}
// AssetNames returns the names of the assets.
func AssetNames() []string {
names := make([]string, 0, len(_bindata))
for name := range _bindata {
names = append(names, name)
}
return names
}
// _bindata is a table, holding each asset generator, mapped to its name.
var _bindata = map[string]func() (*asset, error){
"views/blocks/block.tpl": viewsBlocksBlockTpl,
"views/header.tpl": viewsHeaderTpl,
"views/index.tpl": viewsIndexTpl,
}
// AssetDir returns the file names below a certain
// directory embedded in the file by go-bindata.
// For example if you run go-bindata on data/... and data contains the
// following hierarchy:
// data/
// foo.txt
// img/
// a.png
// b.png
// then AssetDir("data") would return []string{"foo.txt", "img"}
// AssetDir("data/img") would return []string{"a.png", "b.png"}
// AssetDir("foo.txt") and AssetDir("notexist") would return an error
// AssetDir("") will return []string{"data"}.
func AssetDir(name string) ([]string, error) {
node := _bintree
if len(name) != 0 {
cannonicalName := strings.Replace(name, "\\", "/", -1)
pathList := strings.Split(cannonicalName, "/")
for _, p := range pathList {
node = node.Children[p]
if node == nil {
return nil, fmt.Errorf("Asset %s not found", name)
}
}
}
if node.Func != nil {
return nil, fmt.Errorf("Asset %s not found", name)
}
rv := make([]string, 0, len(node.Children))
for childName := range node.Children {
rv = append(rv, childName)
}
return rv, nil
}
type bintree struct {
Func func() (*asset, error)
Children map[string]*bintree
}
var _bintree = &bintree{nil, map[string]*bintree{
"views": &bintree{nil, map[string]*bintree{
"blocks": &bintree{nil, map[string]*bintree{
"block.tpl": &bintree{viewsBlocksBlockTpl, map[string]*bintree{}},
}},
"header.tpl": &bintree{viewsHeaderTpl, map[string]*bintree{}},
"index.tpl": &bintree{viewsIndexTpl, map[string]*bintree{}},
}},
}}
// RestoreAsset restores an asset under the given directory
func RestoreAsset(dir, name string) error {
data, err := Asset(name)
if err != nil {
return err
}
info, err := AssetInfo(name)
if err != nil {
return err
}
err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0755))
if err != nil {
return err
}
err = ioutil.WriteFile(_filePath(dir, name), data, info.Mode())
if err != nil {
return err
}
err = os.Chtimes(_filePath(dir, name), info.ModTime(), info.ModTime())
if err != nil {
return err
}
return nil
}
// RestoreAssets restores an asset under the given directory recursively
func RestoreAssets(dir, name string) error {
children, err := AssetDir(name)
// File
if err != nil {
return RestoreAsset(dir, name)
}
// Dir
for _, child := range children {
err = RestoreAssets(dir, filepath.Join(name, child))
if err != nil {
return err
}
}
return nil
}
func _filePath(dir, name string) string {
cannonicalName := strings.Replace(name, "\\", "/", -1)
return filepath.Join(append([]string{dir}, strings.Split(cannonicalName, "/")...)...)
}
func assetFS() *assetfs.AssetFS {
assetInfo := func(path string) (os.FileInfo, error) {
return os.Stat(path)
}
for k := range _bintree.Children {
return &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, AssetInfo: assetInfo, Prefix: k}
}
panic("unreachable")
}

3
testdata/views/blocks/block.tpl vendored Normal file
View File

@ -0,0 +1,3 @@
{{define "block"}}
<h1>Hello, blocks!</h1>
{{end}}

3
testdata/views/header.tpl vendored Normal file
View File

@ -0,0 +1,3 @@
{{define "header"}}
<h1>Hello, astaxie!</h1>
{{end}}

15
testdata/views/index.tpl vendored Normal file
View File

@ -0,0 +1,15 @@
<!DOCTYPE html>
<html>
<head>
<title>beego welcome template</title>
</head>
<body>
{{template "block"}}
{{template "header"}}
{{template "blocks/block.tpl"}}
<h2>{{ .Title }}</h2>
<p> This is SomeVar: {{ .SomeVar }}</p>
</body>
</html>

Some files were not shown because too many files have changed in this diff Show More