1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-25 19:10:54 +00:00

Merge pull request #2943 from astaxie/develop

1.9.2
This commit is contained in:
astaxie 2017-12-06 23:37:36 +08:00 committed by GitHub
commit bf5c5626ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 821 additions and 204 deletions

View File

@ -1,9 +1,9 @@
language: go language: go
go: go:
- 1.6.4
- 1.7.5 - 1.7.5
- 1.8.1 - 1.8.5
- 1.9.2
services: services:
- redis-server - redis-server
- mysql - mysql
@ -11,7 +11,6 @@ services:
- memcached - memcached
env: env:
- ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db
- ORM_DRIVER=mysql ORM_SOURCE="root:@/orm_test?charset=utf8"
- ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable"
before_install: before_install:
- git clone git://github.com/ideawu/ssdb.git - git clone git://github.com/ideawu/ssdb.git

View File

@ -1,4 +1,5 @@
# Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) # Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) [![Go Report Card](https://goreportcard.com/badge/github.com/astaxie/beego)](https://goreportcard.com/report/github.com/astaxie/beego)
beego is used for rapid development of RESTful APIs, web apps and backend services in Go. beego is used for rapid development of RESTful APIs, web apps and backend services in Go.
It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding. It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding.

View File

@ -67,6 +67,7 @@ func oldMap() map[string]interface{} {
m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly
m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat
m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
m["BConfig.Log.Outputs"] = BConfig.Log.Outputs m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
return m return m

127
app.go
View File

@ -15,13 +15,17 @@
package beego package beego
import ( import (
"crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/fcgi" "net/http/fcgi"
"os" "os"
"path" "path"
"time" "time"
"strings"
"github.com/astaxie/beego/grace" "github.com/astaxie/beego/grace"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
@ -51,8 +55,11 @@ func NewApp() *App {
return app return app
} }
// MiddleWare function for http.Handler
type MiddleWare func(http.Handler) http.Handler
// Run beego application. // Run beego application.
func (app *App) Run() { func (app *App) Run(mws ...MiddleWare) {
addr := BConfig.Listen.HTTPAddr addr := BConfig.Listen.HTTPAddr
if BConfig.Listen.HTTPPort != 0 { if BConfig.Listen.HTTPPort != 0 {
@ -94,6 +101,12 @@ func (app *App) Run() {
} }
app.Server.Handler = app.Handlers app.Server.Handler = app.Handlers
for i:=len(mws)-1;i>=0;i-- {
if mws[i] == nil {
continue
}
app.Server.Handler = mws[i](app.Server.Handler)
}
app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second
app.Server.ErrorLog = logs.GetLogger("HTTP") app.Server.ErrorLog = logs.GetLogger("HTTP")
@ -102,7 +115,7 @@ func (app *App) Run() {
if BConfig.Listen.Graceful { if BConfig.Listen.Graceful {
httpsAddr := BConfig.Listen.HTTPSAddr httpsAddr := BConfig.Listen.HTTPSAddr
app.Server.Addr = httpsAddr app.Server.Addr = httpsAddr
if BConfig.Listen.EnableHTTPS { if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS {
go func() { go func() {
time.Sleep(20 * time.Microsecond) time.Sleep(20 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 { if BConfig.Listen.HTTPSPort != 0 {
@ -112,10 +125,19 @@ func (app *App) Run() {
server := grace.NewServer(httpsAddr, app.Handlers) server := grace.NewServer(httpsAddr, app.Handlers)
server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout server.Server.WriteTimeout = app.Server.WriteTimeout
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { if BConfig.Listen.EnableMutualHTTPS {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil {
endRunning <- true logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
} else {
if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
} }
}() }()
} }
@ -139,7 +161,7 @@ func (app *App) Run() {
} }
// run normal mode // run normal mode
if BConfig.Listen.EnableHTTPS { if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS {
go func() { go func() {
time.Sleep(20 * time.Microsecond) time.Sleep(20 * time.Microsecond)
if BConfig.Listen.HTTPSPort != 0 { if BConfig.Listen.HTTPSPort != 0 {
@ -149,6 +171,19 @@ func (app *App) Run() {
return return
} }
logs.Info("https server Running on https://%s", app.Server.Addr) logs.Info("https server Running on https://%s", app.Server.Addr)
if BConfig.Listen.EnableMutualHTTPS {
pool := x509.NewCertPool()
data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile)
if err != nil {
BeeLogger.Info("MutualHTTPS should provide TrustCaFile")
return
}
pool.AppendCertsFromPEM(data)
app.Server.TLSConfig = &tls.Config{
ClientCAs: pool,
ClientAuth: tls.RequireAndVerifyClientCert,
}
}
if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err) logs.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
@ -207,6 +242,84 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *A
return BeeApp return BeeApp
} }
// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful
// in web applications that inherit most routes from a base webapp via the underscore
// import, and aim to overwrite only certain paths.
// The method parameter can be empty or "*" for all HTTP methods, or a particular
// method type (e.g. "GET" or "POST") for selective removal.
//
// Usage (replace "GET" with "*" for all methods):
// beego.UnregisterFixedRoute("/yourpreviouspath", "GET")
// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage")
func UnregisterFixedRoute(fixedRoute string, method string) *App {
subPaths := splitPath(fixedRoute)
if method == "" || method == "*" {
for m := range HTTPMETHOD {
if _, ok := BeeApp.Handlers.routers[m]; !ok {
continue
}
if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") {
findAndRemoveSingleTree(BeeApp.Handlers.routers[m])
continue
}
findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m)
}
return BeeApp
}
// Single HTTP method
um := strings.ToUpper(method)
if _, ok := BeeApp.Handlers.routers[um]; ok {
if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") {
findAndRemoveSingleTree(BeeApp.Handlers.routers[um])
return BeeApp
}
findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um)
}
return BeeApp
}
func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) {
for i := range entryPointTree.fixrouters {
if entryPointTree.fixrouters[i].prefix == paths[0] {
if len(paths) == 1 {
if len(entryPointTree.fixrouters[i].fixrouters) > 0 {
// If the route had children subtrees, remove just the functional leaf,
// to allow children to function as before
if len(entryPointTree.fixrouters[i].leaves) > 0 {
entryPointTree.fixrouters[i].leaves[0] = nil
entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:]
}
} else {
// Remove the *Tree from the fixrouters slice
entryPointTree.fixrouters[i] = nil
if i == len(entryPointTree.fixrouters)-1 {
entryPointTree.fixrouters = entryPointTree.fixrouters[:i]
} else {
entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...)
}
}
return
}
findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method)
}
}
}
func findAndRemoveSingleTree(entryPointTree *Tree) {
if entryPointTree == nil {
return
}
if len(entryPointTree.fixrouters) > 0 {
// If the route had children subtrees, remove just the functional leaf,
// to allow children to function as before
if len(entryPointTree.leaves) > 0 {
entryPointTree.leaves[0] = nil
entryPointTree.leaves = entryPointTree.leaves[1:]
}
}
}
// Include will generate router file in the router/xxx.go from the controller's comments // Include will generate router file in the router/xxx.go from the controller's comments
// usage: // usage:
// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) // beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})

View File

@ -23,7 +23,7 @@ import (
const ( const (
// VERSION represent beego web framework version. // VERSION represent beego web framework version.
VERSION = "1.9.0" VERSION = "1.9.2"
// DEV is for develop // DEV is for develop
DEV = "dev" DEV = "dev"
@ -67,6 +67,21 @@ func Run(params ...string) {
BeeApp.Run() BeeApp.Run()
} }
// RunWithMiddleWares Run beego application with middlewares.
func RunWithMiddleWares(addr string, mws ...MiddleWare) {
initBeforeHTTPRun()
strs := strings.Split(addr, ":")
if len(strs) > 0 && strs[0] != "" {
BConfig.Listen.HTTPAddr = strs[0]
}
if len(strs) > 1 && strs[1] != "" {
BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1])
}
BeeApp.Run(mws...)
}
func initBeforeHTTPRun() { func initBeforeHTTPRun() {
//init hooks //init hooks
AddAPPStartHook( AddAPPStartHook(

2
cache/cache.go vendored
View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Package cache provide a Cache interface and some implemetn engine // Package cache provide a Cache interface and some implement engine
// Usage: // Usage:
// //
// import( // import(

69
cache/redis/redis.go vendored
View File

@ -32,6 +32,7 @@ package redis
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"strconv" "strconv"
"time" "time"
@ -59,14 +60,23 @@ func NewRedisCache() cache.Cache {
return &Cache{key: DefaultKey} return &Cache{key: DefaultKey}
} }
// actually do the redis cmds // actually do the redis cmds, args[0] must be the key name.
func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
if len(args) < 1 {
return nil, errors.New("missing required arguments")
}
args[0] = rc.associate(args[0])
c := rc.p.Get() c := rc.p.Get()
defer c.Close() defer c.Close()
return c.Do(commandName, args...) return c.Do(commandName, args...)
} }
// associate with config key.
func (rc *Cache) associate(originKey interface{}) string {
return fmt.Sprintf("%s:%s", rc.key, originKey)
}
// Get cache from redis. // Get cache from redis.
func (rc *Cache) Get(key string) interface{} { func (rc *Cache) Get(key string) interface{} {
if v, err := rc.do("GET", key); err == nil { if v, err := rc.do("GET", key); err == nil {
@ -77,57 +87,28 @@ func (rc *Cache) Get(key string) interface{} {
// GetMulti get cache from redis. // GetMulti get cache from redis.
func (rc *Cache) GetMulti(keys []string) []interface{} { func (rc *Cache) GetMulti(keys []string) []interface{} {
size := len(keys)
var rv []interface{}
c := rc.p.Get() c := rc.p.Get()
defer c.Close() defer c.Close()
var err error var args []interface{}
for _, key := range keys { for _, key := range keys {
err = c.Send("GET", key) args = append(args, rc.associate(key))
if err != nil {
goto ERROR
}
} }
if err = c.Flush(); err != nil { values, err := redis.Values(c.Do("MGET", args...))
goto ERROR if err != nil {
return nil
} }
for i := 0; i < size; i++ { return values
if v, err := c.Receive(); err == nil {
rv = append(rv, v.([]byte))
} else {
rv = append(rv, err)
}
}
return rv
ERROR:
rv = rv[0:0]
for i := 0; i < size; i++ {
rv = append(rv, nil)
}
return rv
} }
// Put put cache to redis. // Put put cache to redis.
func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
var err error _, err := rc.do("SETEX", key, int64(timeout/time.Second), val)
if _, err = rc.do("SETEX", key, int64(timeout/time.Second), val); err != nil {
return err
}
if _, err = rc.do("HSET", rc.key, key, true); err != nil {
return err
}
return err return err
} }
// Delete delete cache in redis. // Delete delete cache in redis.
func (rc *Cache) Delete(key string) error { func (rc *Cache) Delete(key string) error {
var err error _, err := rc.do("DEL", key)
if _, err = rc.do("DEL", key); err != nil {
return err
}
_, err = rc.do("HDEL", rc.key, key)
return err return err
} }
@ -137,11 +118,6 @@ func (rc *Cache) IsExist(key string) bool {
if err != nil { if err != nil {
return false return false
} }
if !v {
if _, err = rc.do("HDEL", rc.key, key); err != nil {
return false
}
}
return v return v
} }
@ -159,16 +135,17 @@ func (rc *Cache) Decr(key string) error {
// ClearAll clean all cache in redis. delete this redis collection. // ClearAll clean all cache in redis. delete this redis collection.
func (rc *Cache) ClearAll() error { func (rc *Cache) ClearAll() error {
cachedKeys, err := redis.Strings(rc.do("HKEYS", rc.key)) c := rc.p.Get()
defer c.Close()
cachedKeys, err := redis.Strings(c.Do("KEYS", rc.key+":*"))
if err != nil { if err != nil {
return err return err
} }
for _, str := range cachedKeys { for _, str := range cachedKeys {
if _, err = rc.do("DEL", str); err != nil { if _, err = c.Do("DEL", str); err != nil {
return err return err
} }
} }
_, err = rc.do("DEL", rc.key)
return err return err
} }

View File

@ -49,22 +49,24 @@ type Config struct {
// Listen holds for http and https related config // Listen holds for http and https related config
type Listen struct { type Listen struct {
Graceful bool // Graceful means use graceful module to start the server Graceful bool // Graceful means use graceful module to start the server
ServerTimeOut int64 ServerTimeOut int64
ListenTCP4 bool ListenTCP4 bool
EnableHTTP bool EnableHTTP bool
HTTPAddr string HTTPAddr string
HTTPPort int HTTPPort int
EnableHTTPS bool EnableHTTPS bool
HTTPSAddr string EnableMutualHTTPS bool
HTTPSPort int HTTPSAddr string
HTTPSCertFile string HTTPSPort int
HTTPSKeyFile string HTTPSCertFile string
EnableAdmin bool HTTPSKeyFile string
AdminAddr string TrustCaFile string
AdminPort int EnableAdmin bool
EnableFcgi bool AdminAddr string
EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O AdminPort int
EnableFcgi bool
EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O
} }
// WebConfig holds web related config // WebConfig holds web related config
@ -103,9 +105,10 @@ type SessionConfig struct {
// LogConfig holds Log related config // LogConfig holds Log related config
type LogConfig struct { type LogConfig struct {
AccessLogs bool AccessLogs bool
FileLineNum bool AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string
Outputs map[string]string // Store Adaptor : config FileLineNum bool
Outputs map[string]string // Store Adaptor : config
} }
var ( var (
@ -134,9 +137,13 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
appConfigPath = filepath.Join(workPath, "conf", "app.conf") var filename = "app.conf"
if os.Getenv("BEEGO_MODE") != "" {
filename = os.Getenv("BEEGO_MODE") + ".app.conf"
}
appConfigPath = filepath.Join(workPath, "conf", filename)
if !utils.FileExists(appConfigPath) { if !utils.FileExists(appConfigPath) {
appConfigPath = filepath.Join(AppPath, "conf", "app.conf") appConfigPath = filepath.Join(AppPath, "conf", filename)
if !utils.FileExists(appConfigPath) { if !utils.FileExists(appConfigPath) {
AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()}
return return
@ -239,9 +246,10 @@ func newBConfig() *Config {
}, },
}, },
Log: LogConfig{ Log: LogConfig{
AccessLogs: false, AccessLogs: false,
FileLineNum: true, AccessLogsFormat: "APACHE_FORMAT",
Outputs: map[string]string{"console": ""}, FileLineNum: true,
Outputs: map[string]string{"console": ""},
}, },
} }
} }

View File

@ -20,6 +20,7 @@ import (
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
@ -115,9 +116,8 @@ func (input *BeegoInput) Domain() string {
// if no host info in request, return localhost. // if no host info in request, return localhost.
func (input *BeegoInput) Host() string { func (input *BeegoInput) Host() string {
if input.Context.Request.Host != "" { if input.Context.Request.Host != "" {
hostParts := strings.Split(input.Context.Request.Host, ":") if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil {
if len(hostParts) > 0 { return hostPart
return hostParts[0]
} }
return input.Context.Request.Host return input.Context.Request.Host
} }
@ -206,20 +206,20 @@ func (input *BeegoInput) AcceptsJSON() bool {
// IP returns request client ip. // IP returns request client ip.
// if in proxy, return first proxy id. // if in proxy, return first proxy id.
// if error, return 127.0.0.1. // if error, return RemoteAddr.
func (input *BeegoInput) IP() string { func (input *BeegoInput) IP() string {
ips := input.Proxy() ips := input.Proxy()
if len(ips) > 0 && ips[0] != "" { if len(ips) > 0 && ips[0] != "" {
rip := strings.Split(ips[0], ":") rip, _, err := net.SplitHostPort(ips[0])
return rip[0] if err != nil {
} rip = ips[0]
ip := strings.Split(input.Context.Request.RemoteAddr, ":")
if len(ip) > 0 {
if ip[0] != "[" {
return ip[0]
} }
return rip
} }
return "127.0.0.1" if ip, _, err := net.SplitHostPort(input.Context.Request.RemoteAddr); err == nil {
return ip
}
return input.Context.Request.RemoteAddr
} }
// Proxy returns proxy client ips slice. // Proxy returns proxy client ips slice.
@ -253,9 +253,8 @@ func (input *BeegoInput) SubDomains() string {
// Port returns request client port. // Port returns request client port.
// when error or empty, return 80. // when error or empty, return 80.
func (input *BeegoInput) Port() int { func (input *BeegoInput) Port() int {
parts := strings.Split(input.Context.Request.Host, ":") if _, portPart, err := net.SplitHostPort(input.Context.Request.Host); err == nil {
if len(parts) == 2 { port, _ := strconv.Atoi(portPart)
port, _ := strconv.Atoi(parts[1])
return port return port
} }
return 80 return 80

View File

@ -2,7 +2,9 @@ package grace
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -65,7 +67,7 @@ func (srv *Server) ListenAndServe() (err error) {
log.Println(err) log.Println(err)
return err return err
} }
err = process.Kill() err = process.Signal(syscall.SIGTERM)
if err != nil { if err != nil {
return err return err
} }
@ -114,6 +116,62 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
srv.tlsInnerListener = newGraceListener(l, srv) srv.tlsInnerListener = newGraceListener(l, srv)
srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig) srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Signal(syscall.SIGTERM)
if err != nil {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
}
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming mutual TLS connections.
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
addr := srv.Addr
if addr == "" {
addr = ":https"
}
if srv.TLSConfig == nil {
srv.TLSConfig = &tls.Config{}
}
if srv.TLSConfig.NextProtos == nil {
srv.TLSConfig.NextProtos = []string{"http/1.1"}
}
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return
}
srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
pool := x509.NewCertPool()
data, err := ioutil.ReadFile(trustFile)
if err != nil {
log.Println(err)
return err
}
pool.AppendCertsFromPEM(data)
srv.TLSConfig.ClientCAs = pool
log.Println("Mutual HTTPS")
go srv.handleSignals()
l, err := srv.getListener(addr)
if err != nil {
log.Println(err)
return err
}
srv.tlsInnerListener = newGraceListener(l, srv)
srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
if srv.isChild { if srv.isChild {
process, err := os.FindProcess(os.Getppid()) process, err := os.FindProcess(os.Getppid())
if err != nil { if err != nil {

View File

@ -317,7 +317,19 @@ func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
} }
return b return b
} }
// XMLBody adds request raw body encoding by XML.
func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil {
byts, err := xml.Marshal(obj)
if err != nil {
return b, err
}
b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
b.req.ContentLength = int64(len(byts))
b.req.Header.Set("Content-Type", "application/xml")
}
return b, nil
}
// JSONBody adds request raw body encoding by JSON. // JSONBody adds request raw body encoding by JSON.
func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) {
if b.req.Body == nil && obj != nil { if b.req.Body == nil && obj != nil {

86
logs/accesslog.go Normal file
View File

@ -0,0 +1,86 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package logs
import (
"bytes"
"encoding/json"
"time"
"fmt"
)
const (
apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s\n"
apacheFormat = "APACHE_FORMAT"
jsonFormat = "JSON_FORMAT"
)
// AccessLogRecord struct for holding access log data.
type AccessLogRecord struct {
RemoteAddr string `json:"remote_addr"`
RequestTime time.Time `json:"request_time"`
RequestMethod string `json:"request_method"`
Request string `json:"request"`
ServerProtocol string `json:"server_protocol"`
Host string `json:"host"`
Status int `json:"status"`
BodyBytesSent int64 `json:"body_bytes_sent"`
ElapsedTime time.Duration `json:"elapsed_time"`
HTTPReferrer string `json:"http_referrer"`
HTTPUserAgent string `json:"http_user_agent"`
RemoteUser string `json:"remote_user"`
}
func (r *AccessLogRecord) json() ([]byte, error) {
buffer := &bytes.Buffer{}
encoder := json.NewEncoder(buffer)
disableEscapeHTML(encoder)
err := encoder.Encode(r)
return buffer.Bytes(), err
}
func disableEscapeHTML(i interface{}) {
e, ok := i.(interface {
SetEscapeHTML(bool)
});
if ok {
e.SetEscapeHTML(false)
}
}
// AccessLog - Format and print access log.
func AccessLog(r *AccessLogRecord, format string) {
var msg string
switch format {
case apacheFormat:
timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05")
msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent,
r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent)
case jsonFormat:
fallthrough
default:
jsonData, err := r.json()
if err != nil {
msg = fmt.Sprintf(`{"Error": "%s"}`, err)
} else {
msg = string(jsonData)
}
}
beeLogger.Debug(msg)
}

View File

@ -182,7 +182,7 @@ func (w *fileLogWriter) initFd() error {
if w.Daily { if w.Daily {
go w.dailyRotate(w.dailyOpenTime) go w.dailyRotate(w.dailyOpenTime)
} }
if fInfo.Size() > 0 { if fInfo.Size() > 0 && w.MaxLines > 0 {
count, err := w.lines() count, err := w.lines()
if err != nil { if err != nil {
return err return err

View File

@ -87,13 +87,15 @@ const (
mi2 = `012345678901234567890123456789012345678901234567890123456789` mi2 = `012345678901234567890123456789012345678901234567890123456789`
s1 = `000000000011111111112222222222333333333344444444445555555555` s1 = `000000000011111111112222222222333333333344444444445555555555`
s2 = `012345678901234567890123456789012345678901234567890123456789` s2 = `012345678901234567890123456789012345678901234567890123456789`
ns1 = `0123456789`
) )
func formatTimeHeader(when time.Time) ([]byte, int) { func formatTimeHeader(when time.Time) ([]byte, int) {
y, mo, d := when.Date() y, mo, d := when.Date()
h, mi, s := when.Clock() h, mi, s := when.Clock()
//len("2006/01/02 15:04:05 ")==20 ns := when.Nanosecond()/1000000
var buf [20]byte //len("2006/01/02 15:04:05.123 ")==24
var buf [24]byte
buf[0] = y1[y/1000%10] buf[0] = y1[y/1000%10]
buf[1] = y2[y/100] buf[1] = y2[y/100]
@ -114,7 +116,12 @@ func formatTimeHeader(when time.Time) ([]byte, int) {
buf[16] = ':' buf[16] = ':'
buf[17] = s1[s] buf[17] = s1[s]
buf[18] = s2[s] buf[18] = s2[s]
buf[19] = ' ' buf[19] = '.'
buf[20] = ns1[ns/100]
buf[21] = ns1[ns%100/10]
buf[22] = ns1[ns%10]
buf[23] = ' '
return buf[0:], d return buf[0:], d
} }

View File

@ -31,7 +31,7 @@ func TestFormatHeader_0(t *testing.T) {
break break
} }
h, _ := formatTimeHeader(tm) h, _ := formatTimeHeader(tm)
if tm.Format("2006/01/02 15:04:05 ") != string(h) { if tm.Format("2006/01/02 15:04:05.999 ") != string(h) {
t.Log(tm) t.Log(tm)
t.FailNow() t.FailNow()
} }
@ -49,7 +49,7 @@ func TestFormatHeader_1(t *testing.T) {
break break
} }
h, _ := formatTimeHeader(tm) h, _ := formatTimeHeader(tm)
if tm.Format("2006/01/02 15:04:05 ") != string(h) { if tm.Format("2006/01/02 15:04:05.999 ") != string(h) {
t.Log(tm) t.Log(tm)
t.FailNow() t.FailNow()
} }

View File

@ -172,7 +172,7 @@ func Register(name string, m Migrationer) error {
return nil return nil
} }
// Upgrade upgrate the migration from lasttime // Upgrade upgrade the migration from lasttime
func Upgrade(lasttime int64) error { func Upgrade(lasttime int64) error {
sm := sortMap(migrationMap) sm := sortMap(migrationMap)
i := 0 i := 0

View File

@ -61,6 +61,9 @@ func init() {
// set default database // set default database
orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
// create table
orm.RunSyncdb("default", false, true)
} }
func main() { func main() {

View File

@ -51,12 +51,14 @@ checkColumn:
switch fieldType { switch fieldType {
case TypeBooleanField: case TypeBooleanField:
col = T["bool"] col = T["bool"]
case TypeCharField: case TypeVarCharField:
if al.Driver == DRPostgres && fi.toText { if al.Driver == DRPostgres && fi.toText {
col = T["string-text"] col = T["string-text"]
} else { } else {
col = fmt.Sprintf(T["string"], fieldSize) col = fmt.Sprintf(T["string"], fieldSize)
} }
case TypeCharField:
col = fmt.Sprintf(T["string-char"], fieldSize)
case TypeTextField: case TypeTextField:
col = T["string-text"] col = T["string-text"]
case TypeTimeField: case TypeTimeField:
@ -96,13 +98,13 @@ checkColumn:
} }
case TypeJSONField: case TypeJSONField:
if al.Driver != DRPostgres { if al.Driver != DRPostgres {
fieldType = TypeCharField fieldType = TypeVarCharField
goto checkColumn goto checkColumn
} }
col = T["json"] col = T["json"]
case TypeJsonbField: case TypeJsonbField:
if al.Driver != DRPostgres { if al.Driver != DRPostgres {
fieldType = TypeCharField fieldType = TypeVarCharField
goto checkColumn goto checkColumn
} }
col = T["jsonb"] col = T["jsonb"]

View File

@ -142,7 +142,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
} else { } else {
value = field.Bool() value = field.Bool()
} }
case TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField: case TypeVarCharField, TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField:
if ns, ok := field.Interface().(sql.NullString); ok { if ns, ok := field.Interface().(sql.NullString); ok {
value = nil value = nil
if ns.Valid { if ns.Valid {
@ -1240,7 +1240,7 @@ setValue:
} }
value = b value = b
} }
case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if str == nil { if str == nil {
value = ToStr(val) value = ToStr(val)
} else { } else {
@ -1386,7 +1386,7 @@ setValue:
field.SetBool(value.(bool)) field.SetBool(value.(bool))
} }
} }
case fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField:
if isNative { if isNative {
if ns, ok := field.Interface().(sql.NullString); ok { if ns, ok := field.Interface().(sql.NullString); ok {
if value == nil { if value == nil {

View File

@ -119,7 +119,7 @@ type alias struct {
func detectTZ(al *alias) { func detectTZ(al *alias) {
// orm timezone system match database // orm timezone system match database
// default use Local // default use Local
al.TZ = time.Local al.TZ = DefaultTimeLoc
if al.DriverName == "sphinx" { if al.DriverName == "sphinx" {
return return
@ -136,7 +136,9 @@ func detectTZ(al *alias) {
} }
t, err := time.Parse("-07:00:00", tz) t, err := time.Parse("-07:00:00", tz)
if err == nil { if err == nil {
al.TZ = t.Location() if t.Location().String() != "" {
al.TZ = t.Location()
}
} else { } else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
} }

View File

@ -46,6 +46,7 @@ var mysqlTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
"bool": "bool", "bool": "bool",
"string": "varchar(%d)", "string": "varchar(%d)",
"string-char": "char(%d)",
"string-text": "longtext", "string-text": "longtext",
"time.Time-date": "date", "time.Time-date": "date",
"time.Time": "datetime", "time.Time": "datetime",

View File

@ -34,6 +34,7 @@ var oracleTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
"bool": "bool", "bool": "bool",
"string": "VARCHAR2(%d)", "string": "VARCHAR2(%d)",
"string-char": "CHAR(%d)",
"string-text": "VARCHAR2(%d)", "string-text": "VARCHAR2(%d)",
"time.Time-date": "DATE", "time.Time-date": "DATE",
"time.Time": "TIMESTAMP", "time.Time": "TIMESTAMP",

View File

@ -43,6 +43,7 @@ var postgresTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
"bool": "bool", "bool": "bool",
"string": "varchar(%d)", "string": "varchar(%d)",
"string-char": "char(%d)",
"string-text": "text", "string-text": "text",
"time.Time-date": "date", "time.Time-date": "date",
"time.Time": "timestamp with time zone", "time.Time": "timestamp with time zone",

View File

@ -43,6 +43,7 @@ var sqliteTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY",
"bool": "bool", "bool": "bool",
"string": "varchar(%d)", "string": "varchar(%d)",
"string-char": "character(%d)",
"string-text": "text", "string-text": "text",
"time.Time-date": "date", "time.Time-date": "date",
"time.Time": "datetime", "time.Time": "datetime",

View File

@ -52,7 +52,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
return m return m
} }
// get orderd model info // get ordered model info
func (mc *_modelCache) allOrdered() []*modelInfo { func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders)) m := make([]*modelInfo, 0, len(mc.orders))
for _, table := range mc.orders { for _, table := range mc.orders {

View File

@ -89,7 +89,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
modelCache.set(table, mi) modelCache.set(table, mi)
} }
// boostrap models // bootstrap models
func bootStrap() { func bootStrap() {
if modelCache.done { if modelCache.done {
return return
@ -332,7 +332,7 @@ func RegisterModelWithSuffix(suffix string, models ...interface{}) {
} }
} }
// BootStrap bootrap models. // BootStrap bootstrap models.
// make all model parsed and can not add more models // make all model parsed and can not add more models
func BootStrap() { func BootStrap() {
if modelCache.done { if modelCache.done {

View File

@ -23,6 +23,7 @@ import (
// Define the Type enum // Define the Type enum
const ( const (
TypeBooleanField = 1 << iota TypeBooleanField = 1 << iota
TypeVarCharField
TypeCharField TypeCharField
TypeTextField TypeTextField
TypeTimeField TypeTimeField
@ -49,9 +50,9 @@ const (
// Define some logic enum // Define some logic enum
const ( const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 5 << 6 IsIntegerField = ^-TypePositiveBigIntegerField >> 6 << 7
IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 9 << 10 IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11
IsRelField = ^-RelReverseMany >> 17 << 18 IsRelField = ^-RelReverseMany >> 18 << 19
IsFieldType = ^-RelReverseMany<<1 + 1 IsFieldType = ^-RelReverseMany<<1 + 1
) )
@ -126,7 +127,7 @@ func (e *CharField) String() string {
// FieldType return the enum type // FieldType return the enum type
func (e *CharField) FieldType() int { func (e *CharField) FieldType() int {
return TypeCharField return TypeVarCharField
} }
// SetRaw set the interface to string // SetRaw set the interface to string
@ -232,7 +233,7 @@ func (e *DateField) Set(d time.Time) {
*e = DateField(d) *e = DateField(d)
} }
// String convert datatime to string // String convert datetime to string
func (e *DateField) String() string { func (e *DateField) String() string {
return e.Value().String() return e.Value().String()
} }
@ -272,12 +273,12 @@ var _ Fielder = new(DateField)
// Takes the same extra arguments as DateField. // Takes the same extra arguments as DateField.
type DateTimeField time.Time type DateTimeField time.Time
// Value return the datatime value // Value return the datetime value
func (e DateTimeField) Value() time.Time { func (e DateTimeField) Value() time.Time {
return time.Time(e) return time.Time(e)
} }
// Set set the time.Time to datatime // Set set the time.Time to datetime
func (e *DateTimeField) Set(d time.Time) { func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d) *e = DateTimeField(d)
} }
@ -309,12 +310,12 @@ func (e *DateTimeField) SetRaw(value interface{}) error {
return nil return nil
} }
// RawValue return the datatime value // RawValue return the datetime value
func (e *DateTimeField) RawValue() interface{} { func (e *DateTimeField) RawValue() interface{} {
return e.Value() return e.Value()
} }
// verify datatime implement fielder // verify datetime implement fielder
var _ Fielder = new(DateTimeField) var _ Fielder = new(DateTimeField)
// FloatField A floating-point number represented in go by a float32 value. // FloatField A floating-point number represented in go by a float32 value.

View File

@ -244,8 +244,10 @@ checkType:
if err != nil { if err != nil {
goto end goto end
} }
if fieldType == TypeCharField { if fieldType == TypeVarCharField {
switch tags["type"] { switch tags["type"] {
case "char":
fieldType = TypeCharField
case "text": case "text":
fieldType = TypeTextField fieldType = TypeTextField
case "json": case "json":
@ -357,7 +359,7 @@ checkType:
switch fieldType { switch fieldType {
case TypeBooleanField: case TypeBooleanField:
case TypeCharField, TypeJSONField, TypeJsonbField: case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField:
if size != "" { if size != "" {
v, e := StrTo(size).Int32() v, e := StrTo(size).Int32()
if e != nil { if e != nil {

View File

@ -49,7 +49,7 @@ func (e *SliceStringField) String() string {
} }
func (e *SliceStringField) FieldType() int { func (e *SliceStringField) FieldType() int {
return TypeCharField return TypeVarCharField
} }
func (e *SliceStringField) SetRaw(value interface{}) error { func (e *SliceStringField) SetRaw(value interface{}) error {

View File

@ -149,7 +149,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
case reflect.TypeOf(new(bool)): case reflect.TypeOf(new(bool)):
ft = TypeBooleanField ft = TypeBooleanField
case reflect.TypeOf(new(string)): case reflect.TypeOf(new(string)):
ft = TypeCharField ft = TypeVarCharField
case reflect.TypeOf(new(time.Time)): case reflect.TypeOf(new(time.Time)):
ft = TypeDateTimeField ft = TypeDateTimeField
default: default:
@ -176,7 +176,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
case reflect.Bool: case reflect.Bool:
ft = TypeBooleanField ft = TypeBooleanField
case reflect.String: case reflect.String:
ft = TypeCharField ft = TypeVarCharField
default: default:
if elm.Interface() == nil { if elm.Interface() == nil {
panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val))
@ -189,7 +189,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
case sql.NullBool: case sql.NullBool:
ft = TypeBooleanField ft = TypeBooleanField
case sql.NullString: case sql.NullString:
ft = TypeCharField ft = TypeVarCharField
case time.Time: case time.Time:
ft = TypeDateTimeField ft = TypeDateTimeField
} }

157
router.go
View File

@ -50,23 +50,23 @@ const (
var ( var (
// HTTPMETHOD list the supported http methods. // HTTPMETHOD list the supported http methods.
HTTPMETHOD = map[string]string{ HTTPMETHOD = map[string]bool{
"GET": "GET", "GET": true,
"POST": "POST", "POST": true,
"PUT": "PUT", "PUT": true,
"DELETE": "DELETE", "DELETE": true,
"PATCH": "PATCH", "PATCH": true,
"OPTIONS": "OPTIONS", "OPTIONS": true,
"HEAD": "HEAD", "HEAD": true,
"TRACE": "TRACE", "TRACE": true,
"CONNECT": "CONNECT", "CONNECT": true,
"MKCOL": "MKCOL", "MKCOL": true,
"COPY": "COPY", "COPY": true,
"MOVE": "MOVE", "MOVE": true,
"PROPFIND": "PROPFIND", "PROPFIND": true,
"PROPPATCH": "PROPPATCH", "PROPPATCH": true,
"LOCK": "LOCK", "LOCK": true,
"UNLOCK": "UNLOCK", "UNLOCK": true,
} }
// these beego.Controller's methods shouldn't reflect to AutoRouter // these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
@ -117,6 +117,7 @@ type ControllerInfo struct {
handler http.Handler handler http.Handler
runFunction FilterFunc runFunction FilterFunc
routerType int routerType int
initialize func() ControllerInterface
methodParams []*param.MethodParam methodParams []*param.MethodParam
} }
@ -169,7 +170,7 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
} }
comma := strings.Split(colon[0], ",") comma := strings.Split(colon[0], ",")
for _, m := range comma { for _, m := range comma {
if _, ok := HTTPMETHOD[strings.ToUpper(m)]; m == "*" || ok { if m == "*" || HTTPMETHOD[strings.ToUpper(m)] {
if val := reflectVal.MethodByName(colon[1]); val.IsValid() { if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
methods[strings.ToUpper(m)] = colon[1] methods[strings.ToUpper(m)] = colon[1]
} else { } else {
@ -187,15 +188,36 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt
route.methods = methods route.methods = methods
route.routerType = routerTypeBeego route.routerType = routerTypeBeego
route.controllerType = t route.controllerType = t
route.initialize = func() ControllerInterface {
vc := reflect.New(route.controllerType)
execController, ok := vc.Interface().(ControllerInterface)
if !ok {
panic("controller is not ControllerInterface")
}
elemVal := reflect.ValueOf(c).Elem()
elemType := reflect.TypeOf(c).Elem()
execElem := reflect.ValueOf(execController).Elem()
numOfFields := elemVal.NumField()
for i := 0; i < numOfFields; i++ {
fieldVal := elemVal.Field(i)
fieldType := elemType.Field(i)
execElem.FieldByName(fieldType.Name).Set(fieldVal)
}
return execController
}
route.methodParams = methodParams route.methodParams = methodParams
if len(methods) == 0 { if len(methods) == 0 {
for _, m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
} }
} else { } else {
for k := range methods { for k := range methods {
if k == "*" { if k == "*" {
for _, m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
} }
} else { } else {
@ -337,7 +359,7 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) {
// }) // })
func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
method = strings.ToUpper(method) method = strings.ToUpper(method)
if _, ok := HTTPMETHOD[method]; method != "*" && !ok { if method != "*" && !HTTPMETHOD[method] {
panic("not support http method: " + method) panic("not support http method: " + method)
} }
route := &ControllerInfo{} route := &ControllerInfo{}
@ -346,7 +368,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
route.runFunction = f route.runFunction = f
methods := make(map[string]string) methods := make(map[string]string)
if method == "*" { if method == "*" {
for _, val := range HTTPMETHOD { for val := range HTTPMETHOD {
methods[val] = val methods[val] = val
} }
} else { } else {
@ -355,7 +377,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
route.methods = methods route.methods = methods
for k := range methods { for k := range methods {
if k == "*" { if k == "*" {
for _, m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
} }
} else { } else {
@ -375,7 +397,7 @@ func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...
pattern = path.Join(pattern, "?:all(.*)") pattern = path.Join(pattern, "?:all(.*)")
} }
} }
for _, m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
} }
} }
@ -410,7 +432,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
route.pattern = pattern route.pattern = pattern
for _, m := range HTTPMETHOD { for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
p.addToRouter(m, patternInit, route) p.addToRouter(m, patternInit, route)
p.addToRouter(m, patternFix, route) p.addToRouter(m, patternFix, route)
@ -511,7 +533,7 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
if c.routerType == routerTypeBeego && if c.routerType == routerTypeBeego &&
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) { strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) {
find := false find := false
if _, ok := HTTPMETHOD[strings.ToUpper(methodName)]; ok { if HTTPMETHOD[strings.ToUpper(methodName)] {
if len(c.methods) == 0 { if len(c.methods) == 0 {
find = true find = true
} else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) { } else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) {
@ -659,7 +681,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
// filter wrong http method // filter wrong http method
if _, ok := HTTPMETHOD[r.Method]; !ok { if !HTTPMETHOD[r.Method] {
http.Error(rw, "Method Not Allowed", 405) http.Error(rw, "Method Not Allowed", 405)
goto Admin goto Admin
} }
@ -768,14 +790,20 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// also defined runRouter & runMethod from filter // also defined runRouter & runMethod from filter
if !isRunnable { if !isRunnable {
//Invoke the request handler //Invoke the request handler
vc := reflect.New(runRouter) var execController ControllerInterface
execController, ok := vc.Interface().(ControllerInterface) if routerInfo.initialize != nil {
if !ok { execController = routerInfo.initialize()
panic("controller is not ControllerInterface") } else {
vc := reflect.New(runRouter)
var ok bool
execController, ok = vc.Interface().(ControllerInterface)
if !ok {
panic("controller is not ControllerInterface")
}
} }
//call the controller init function //call the controller init function
execController.Init(context, runRouter.Name(), runMethod, vc.Interface()) execController.Init(context, runRouter.Name(), runMethod, execController)
//call prepare function //call prepare function
execController.Prepare() execController.Prepare()
@ -810,6 +838,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.Options() execController.Options()
default: default:
if !execController.HandlerFunc(runMethod) { if !execController.HandlerFunc(runMethod) {
vc := reflect.ValueOf(execController)
method := vc.MethodByName(runMethod) method := vc.MethodByName(runMethod)
in := param.ConvertParams(methodParams, method.Type(), context) in := param.ConvertParams(methodParams, method.Type(), context)
out := method.Call(in) out := method.Call(in)
@ -846,16 +875,19 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
Admin: Admin:
//admin module record QPS //admin module record QPS
statusCode := context.ResponseWriter.Status
if statusCode == 0 {
statusCode = 200
}
if BConfig.Listen.EnableAdmin { if BConfig.Listen.EnableAdmin {
timeDur := time.Since(startTime) timeDur := time.Since(startTime)
pattern := "" pattern := ""
if routerInfo != nil { if routerInfo != nil {
pattern = routerInfo.pattern pattern = routerInfo.pattern
} }
statusCode := context.ResponseWriter.Status
if statusCode == 0 {
statusCode = 200
}
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) { if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) {
if runRouter != nil { if runRouter != nil {
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur) go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
@ -869,36 +901,47 @@ Admin:
timeDur := time.Since(startTime) timeDur := time.Since(startTime)
var devInfo string var devInfo string
statusCode := context.ResponseWriter.Status
if statusCode == 0 {
statusCode = 200
}
iswin := (runtime.GOOS == "windows") iswin := (runtime.GOOS == "windows")
statusColor := logs.ColorByStatus(iswin, statusCode) statusColor := logs.ColorByStatus(iswin, statusCode)
methodColor := logs.ColorByMethod(iswin, r.Method) methodColor := logs.ColorByMethod(iswin, r.Method)
resetColor := logs.ColorByMethod(iswin, "") resetColor := logs.ColorByMethod(iswin, "")
if BConfig.Log.AccessLogsFormat != "" {
if findRouter { record := &logs.AccessLogRecord{
if routerInfo != nil { RemoteAddr: context.Input.IP(),
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s r:%s", context.Input.IP(), statusColor, statusCode, RequestTime: startTime,
resetColor, timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path, RequestMethod: r.Method,
routerInfo.pattern) Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto),
ServerProtocol: r.Proto,
Host: r.Host,
Status: statusCode,
ElapsedTime: timeDur,
HTTPReferrer: r.Header.Get("Referer"),
HTTPUserAgent: r.Header.Get("User-Agent"),
RemoteUser: r.Header.Get("Remote-User"),
BodyBytesSent: 0, //@todo this one is missing!
}
logs.AccessLog(record, BConfig.Log.AccessLogsFormat)
} else {
if findRouter {
if routerInfo != nil {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s r:%s", context.Input.IP(), statusColor, statusCode,
resetColor, timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path,
routerInfo.pattern)
} else {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path)
}
} else { } else {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor, devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path) timeDur.String(), "nomatch", methodColor, r.Method, resetColor, r.URL.Path)
}
if iswin {
logs.W32Debug(devInfo)
} else {
logs.Debug(devInfo)
} }
} else {
devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
timeDur.String(), "nomatch", methodColor, r.Method, resetColor, r.URL.Path)
}
if iswin {
logs.W32Debug(devInfo)
} else {
logs.Debug(devInfo)
} }
} }
// Call WriteHeader if status code has been set changed // Call WriteHeader if status code has been set changed
if context.Output.Status != 0 { if context.Output.Status != 0 {
context.ResponseWriter.WriteHeader(context.Output.Status) context.ResponseWriter.WriteHeader(context.Output.Status)

View File

@ -160,10 +160,13 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
return nil, err return nil, err
} }
} }
_, err = c.Do("SELECT", rp.dbNum) //some redis proxy such as twemproxy is not support select command
if err != nil { if rp.dbNum > 0 {
c.Close() _, err = c.Do("SELECT", rp.dbNum)
return nil, err if err != nil {
c.Close()
return nil, err
}
} }
return c, err return c, err
}, rp.poolsize) }, rp.poolsize)

View File

@ -78,6 +78,8 @@ func (fs *FileSessionStore) SessionID() string {
// SessionRelease Write file session to local file with Gob string // SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
b, err := EncodeGob(fs.values) b, err := EncodeGob(fs.values)
if err != nil { if err != nil {
SLogger.Println(err) SLogger.Println(err)
@ -164,7 +166,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
} }
// SessionExist Check file session exist. // SessionExist Check file session exist.
// it checkes the file named from sid exist or not. // it checks the file named from sid exist or not.
func (fp *FileProvider) SessionExist(sid string) bool { func (fp *FileProvider) SessionExist(sid string) bool {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()

View File

@ -149,7 +149,7 @@ func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime
// 2. Verify MAC. Value is "date|value|mac". // 2. Verify MAC. Value is "date|value|mac".
parts := bytes.SplitN(b, []byte("|"), 3) parts := bytes.SplitN(b, []byte("|"), 3)
if len(parts) != 3 { if len(parts) != 3 {
return nil, errors.New("Decode: invalid value %v") return nil, errors.New("Decode: invalid value format")
} }
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...) b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)

View File

@ -121,6 +121,7 @@ type Schema struct {
Type string `json:"type,omitempty" yaml:"type,omitempty"` Type string `json:"type,omitempty" yaml:"type,omitempty"`
Items *Schema `json:"items,omitempty" yaml:"items,omitempty"` Items *Schema `json:"items,omitempty" yaml:"items,omitempty"`
Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"` Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"`
Enum []interface{} `json:"enum,omitempty" yaml:"enum,omitempty"`
} }
// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification // Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification
@ -141,7 +142,7 @@ type Propertie struct {
// Response as they are returned from executing this operation. // Response as they are returned from executing this operation.
type Response struct { type Response struct {
Description string `json:"description,omitempty" yaml:"description,omitempty"` Description string `json:"description" yaml:"description"`
Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"` Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"`
Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"` Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"`
} }

View File

@ -218,9 +218,9 @@ func BuildTemplate(dir string, files ...string) error {
} }
if err != nil { if err != nil {
logs.Error("parse template err:", file, err) logs.Error("parse template err:", file, err)
} else { return err
beeTemplates[file] = t
} }
beeTemplates[file] = t
templatesLock.Unlock() templatesLock.Unlock()
} }
} }

View File

@ -28,7 +28,7 @@ var (
) )
// Tree has three elements: FixRouter/wildcard/leaves // Tree has three elements: FixRouter/wildcard/leaves
// fixRouter sotres Fixed Router // fixRouter stores Fixed Router
// wildcard stores params // wildcard stores params
// leaves store the endpoint information // leaves store the endpoint information
type Tree struct { type Tree struct {

226
unregroute_test.go Normal file
View File

@ -0,0 +1,226 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package beego
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
//
// The unregroute_test.go contains tests for the unregister route
// functionality, that allows overriding route paths in children project
// that embed parent routers.
//
const contentRootOriginal = "ok-original-root"
const contentLevel1Original = "ok-original-level1"
const contentLevel2Original = "ok-original-level2"
const contentRootReplacement = "ok-replacement-root"
const contentLevel1Replacement = "ok-replacement-level1"
const contentLevel2Replacement = "ok-replacement-level2"
// TestPreUnregController will supply content for the original routes,
// before unregistration
type TestPreUnregController struct {
Controller
}
func (tc *TestPreUnregController) GetFixedRoot() {
tc.Ctx.Output.Body([]byte(contentRootOriginal))
}
func (tc *TestPreUnregController) GetFixedLevel1() {
tc.Ctx.Output.Body([]byte(contentLevel1Original))
}
func (tc *TestPreUnregController) GetFixedLevel2() {
tc.Ctx.Output.Body([]byte(contentLevel2Original))
}
// TestPostUnregController will supply content for the overriding routes,
// after the original ones are unregistered.
type TestPostUnregController struct {
Controller
}
func (tc *TestPostUnregController) GetFixedRoot() {
tc.Ctx.Output.Body([]byte(contentRootReplacement))
}
func (tc *TestPostUnregController) GetFixedLevel1() {
tc.Ctx.Output.Body([]byte(contentLevel1Replacement))
}
func (tc *TestPostUnregController) GetFixedLevel2() {
tc.Ctx.Output.Body([]byte(contentLevel2Replacement))
}
// TestUnregisterFixedRouteRoot replaces just the root fixed route path.
// In this case, for a path like "/level1/level2" or "/level1", those actions
// should remain intact, and continue to serve the original content.
func TestUnregisterFixedRouteRoot(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
// Test original root
testHelperFnContentCheck(t, handler, "Test original root",
method, "/", contentRootOriginal)
// Test original level 1
testHelperFnContentCheck(t, handler, "Test original level 1",
method, "/level1", contentLevel1Original)
// Test original level 2
testHelperFnContentCheck(t, handler, "Test original level 2",
method, "/level1/level2", contentLevel2Original)
// Remove only the root path
findAndRemoveSingleTree(handler.routers[method])
// Replace the root path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot")
// Test replacement root (expect change)
testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement)
// Test level 1 (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original)
// Test level 2 (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original)
}
// TestUnregisterFixedRouteLevel1 replaces just the "/level1" fixed route path.
// In this case, for a path like "/level1/level2" or "/", those actions
// should remain intact, and continue to serve the original content.
func TestUnregisterFixedRouteLevel1(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
// Test original root
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original root",
method, "/", contentRootOriginal)
// Test original level 1
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original level 1",
method, "/level1", contentLevel1Original)
// Test original level 2
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original level 2",
method, "/level1/level2", contentLevel2Original)
// Remove only the level1 path
subPaths := splitPath("/level1")
if handler.routers[method].prefix == strings.Trim("/level1", "/ ") {
findAndRemoveSingleTree(handler.routers[method])
} else {
findAndRemoveTree(subPaths, handler.routers[method], method)
}
// Replace the "level1" path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1")
// Test replacement root (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
// Test level 1 (expect change)
testHelperFnContentCheck(t, handler, "Test level 1 (expect change)", method, "/level1", contentLevel1Replacement)
// Test level 2 (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original)
}
// TestUnregisterFixedRouteLevel2 unregisters just the "/level1/level2" fixed
// route path. In this case, for a path like "/level1" or "/", those actions
// should remain intact, and continue to serve the original content.
func TestUnregisterFixedRouteLevel2(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
// Test original root
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original root",
method, "/", contentRootOriginal)
// Test original level 1
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original level 1",
method, "/level1", contentLevel1Original)
// Test original level 2
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original level 2",
method, "/level1/level2", contentLevel2Original)
// Remove only the level2 path
subPaths := splitPath("/level1/level2")
if handler.routers[method].prefix == strings.Trim("/level1/level2", "/ ") {
findAndRemoveSingleTree(handler.routers[method])
} else {
findAndRemoveTree(subPaths, handler.routers[method], method)
}
// Replace the "/level1/level2" path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2")
// Test replacement root (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
// Test level 1 (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original)
// Test level 2 (expect change)
testHelperFnContentCheck(t, handler, "Test level 2 (expect change)", method, "/level1/level2", contentLevel2Replacement)
}
func testHelperFnContentCheck(t *testing.T, handler *ControllerRegister,
testName, method, path, expectedBodyContent string) {
r, err := http.NewRequest(method, path, nil)
if err != nil {
t.Errorf("httpRecorderBodyTest NewRequest error: %v", err)
return
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
body := w.Body.String()
if body != expectedBodyContent {
t.Errorf("%s: expected [%s], got [%s];", testName, expectedBodyContent, body)
}
}

View File

@ -31,6 +31,22 @@ func TestSet(t *testing.T) {
} }
} }
func TestReSet(t *testing.T) {
safeMap := NewBeeMap()
if ok := safeMap.Set("astaxie", 1); !ok {
t.Error("expected", true, "got", false)
}
// set diff value
if ok := safeMap.Set("astaxie", -1); !ok {
t.Error("expected", true, "got", false)
}
// set same value
if ok := safeMap.Set("astaxie", -1); ok {
t.Error("expected", false, "got", true)
}
}
func TestCheck(t *testing.T) { func TestCheck(t *testing.T) {
if exists := safeMap.Check("astaxie"); !exists { if exists := safeMap.Check("astaxie"); !exists {
t.Error("expected", true, "got", false) t.Error("expected", true, "got", false)
@ -50,6 +66,21 @@ func TestDelete(t *testing.T) {
} }
} }
func TestItems(t *testing.T) {
safeMap := NewBeeMap()
safeMap.Set("astaxie", "hello")
for k, v := range safeMap.Items() {
key := k.(string)
value := v.(string)
if key != "astaxie" {
t.Error("expected the key should be astaxie")
}
if value != "hello" {
t.Error("expected the value should be hello")
}
}
}
func TestCount(t *testing.T) { func TestCount(t *testing.T) {
if count := safeMap.Count(); count != 0 { if count := safeMap.Count(); count != 0 {
t.Error("expected count to be", 0, "got", count) t.Error("expected count to be", 0, "got", count)

View File

@ -112,7 +112,7 @@ type Validation struct {
RequiredFirst bool RequiredFirst bool
Errors []*Error Errors []*Error
ErrorsMap map[string]*Error ErrorsMap map[string][]*Error
} }
// Clear Clean all ValidationError. // Clear Clean all ValidationError.
@ -129,7 +129,7 @@ func (v *Validation) HasErrors() bool {
// ErrorMap Return the errors mapped by key. // ErrorMap Return the errors mapped by key.
// If there are multiple validation errors associated with a single key, the // If there are multiple validation errors associated with a single key, the
// first one "wins". (Typically the first validation will be the more basic). // first one "wins". (Typically the first validation will be the more basic).
func (v *Validation) ErrorMap() map[string]*Error { func (v *Validation) ErrorMap() map[string][]*Error {
return v.ErrorsMap return v.ErrorsMap
} }
@ -278,14 +278,35 @@ func (v *Validation) apply(chk Validator, obj interface{}) *Result {
} }
} }
// AddError adds independent error message for the provided key
func (v *Validation) AddError(key, message string) {
Name := key
Field := ""
parts := strings.Split(key, ".")
if len(parts) == 2 {
Field = parts[0]
Name = parts[1]
}
err := &Error{
Message: message,
Key: key,
Name: Name,
Field: Field,
}
v.setError(err)
}
func (v *Validation) setError(err *Error) { func (v *Validation) setError(err *Error) {
v.Errors = append(v.Errors, err) v.Errors = append(v.Errors, err)
if v.ErrorsMap == nil { if v.ErrorsMap == nil {
v.ErrorsMap = make(map[string]*Error) v.ErrorsMap = make(map[string][]*Error)
} }
if _, ok := v.ErrorsMap[err.Field]; !ok { if _, ok := v.ErrorsMap[err.Field]; !ok {
v.ErrorsMap[err.Field] = err v.ErrorsMap[err.Field] = []*Error{}
} }
v.ErrorsMap[err.Field] = append(v.ErrorsMap[err.Field], err)
} }
// SetError Set error message for one field in ValidationError // SetError Set error message for one field in ValidationError