Merge branch 'master' into master

This commit is contained in:
Daniel Lin 2017-10-17 04:30:59 -05:00 committed by GitHub
commit 72ec4df679
134 changed files with 5886 additions and 1023 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
.idea .idea
.vscode
.DS_Store .DS_Store
*.swp *.swp
*.swo *.swo

4
.gosimpleignore Normal file
View File

@ -0,0 +1,4 @@
github.com/astaxie/beego/*/*:S1012
github.com/astaxie/beego/*:S1012
github.com/astaxie/beego/*/*:S1007
github.com/astaxie/beego/*:S1007

View File

@ -1,9 +1,9 @@
language: go language: go
go: go:
- 1.6 - 1.6.4
- 1.5.3 - 1.7.5
- 1.4.3 - 1.8.1
services: services:
- redis-server - redis-server
- mysql - mysql
@ -31,6 +31,14 @@ install:
- go get github.com/siddontang/ledisdb/config - go get github.com/siddontang/ledisdb/config
- go get github.com/siddontang/ledisdb/ledis - go get github.com/siddontang/ledisdb/ledis
- go get github.com/ssdb/gossdb/ssdb - go get github.com/ssdb/gossdb/ssdb
- go get github.com/cloudflare/golz4
- go get github.com/gogo/protobuf/proto
- go get github.com/Knetic/govaluate
- go get github.com/casbin/casbin
- go get -u honnef.co/go/tools/cmd/gosimple
- go get -u github.com/mdempsky/unconvert
- go get -u github.com/gordonklaus/ineffassign
- go get -u github.com/golang/lint/golint
before_script: before_script:
- psql --version - psql --version
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
@ -45,5 +53,10 @@ after_script:
- rm -rf ./res/var/* - rm -rf ./res/var/*
script: script:
- go test -v ./... - go test -v ./...
- gosimple -ignore "$(cat .gosimpleignore)" $(go list ./... | grep -v /vendor/)
- unconvert $(go list ./... | grep -v /vendor/)
- ineffassign .
- find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s
- golint ./...
addons: addons:
postgresql: "9.4" postgresql: "9.4"

View File

@ -1,20 +1,17 @@
## Beego # Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org)
[![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego)
[![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego)
[![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org)
beego is used for rapid development of RESTful APIs, web apps and backend services in Go. beego is used for rapid development of RESTful APIs, web apps and backend services in Go.
It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding. It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding.
More info [beego.me](http://beego.me) ###### More info at [beego.me](http://beego.me).
##Quick Start ## Quick Start
######Download and install
#### Download and install
go get github.com/astaxie/beego go get github.com/astaxie/beego
######Create file `hello.go` #### Create file `hello.go`
```go ```go
package main package main
@ -24,15 +21,16 @@ func main(){
beego.Run() beego.Run()
} }
``` ```
######Build and run #### Build and run
```bash
go build hello.go go build hello.go
./hello ./hello
```
######Congratulations! #### Go to [http://localhost:8080](http://localhost:8080)
You just built your first beego app.
Open your browser and visit `http://localhost:8080`. Congratulations! You've just built your first **beego** app.
Please see [Documentation](http://beego.me/docs) for more.
###### Please see [Documentation](http://beego.me/docs) for more.
## Features ## Features
@ -56,7 +54,7 @@ Please see [Documentation](http://beego.me/docs) for more.
* [http://beego.me/community](http://beego.me/community) * [http://beego.me/community](http://beego.me/community)
* Welcome to join us in Slack: [https://beego.slack.com](https://beego.slack.com), you can get invited from [here](https://github.com/beego/beedoc/issues/232) * Welcome to join us in Slack: [https://beego.slack.com](https://beego.slack.com), you can get invited from [here](https://github.com/beego/beedoc/issues/232)
## LICENSE ## License
beego source code is licensed under the Apache Licence, Version 2.0 beego source code is licensed under the Apache Licence, Version 2.0
(http://www.apache.org/licenses/LICENSE-2.0.html). (http://www.apache.org/licenses/LICENSE-2.0.html).

View File

@ -37,7 +37,7 @@ var beeAdminApp *adminApp
// FilterMonitorFunc is default monitor filter when admin module is enable. // FilterMonitorFunc is default monitor filter when admin module is enable.
// if this func returns, admin module records qbs for this request by condition of this function logic. // if this func returns, admin module records qbs for this request by condition of this function logic.
// usage: // usage:
// func MyFilterMonitor(method, requestPath string, t time.Duration) bool { // func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool {
// if method == "POST" { // if method == "POST" {
// return false // return false
// } // }
@ -50,7 +50,7 @@ var beeAdminApp *adminApp
// return true // return true
// } // }
// beego.FilterMonitorFunc = MyFilterMonitor. // beego.FilterMonitorFunc = MyFilterMonitor.
var FilterMonitorFunc func(string, string, time.Duration) bool var FilterMonitorFunc func(string, string, time.Duration, string, int) bool
func init() { func init() {
beeAdminApp = &adminApp{ beeAdminApp = &adminApp{
@ -62,7 +62,7 @@ func init() {
beeAdminApp.Route("/healthcheck", healthcheck) beeAdminApp.Route("/healthcheck", healthcheck)
beeAdminApp.Route("/task", taskStatus) beeAdminApp.Route("/task", taskStatus)
beeAdminApp.Route("/listconf", listConf) beeAdminApp.Route("/listconf", listConf)
FilterMonitorFunc = func(string, string, time.Duration) bool { return true } FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true }
} }
// AdminIndex is the default http.Handler for admin module. // AdminIndex is the default http.Handler for admin module.
@ -105,29 +105,12 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
tmpl.Execute(rw, data) tmpl.Execute(rw, data)
case "router": case "router":
var ( content := PrintTree()
content = map[string]interface{}{ content["Fields"] = []string{
"Fields": []string{ "Router Pattern",
"Router Pattern", "Methods",
"Methods", "Controller",
"Controller",
},
}
methods = []string{}
methodsData = make(map[string]interface{})
)
for method, t := range BeeApp.Handlers.routers {
resultList := new([][]string)
printTree(resultList, t)
methods = append(methods, method)
methodsData[method] = resultList
} }
content["Data"] = methodsData
content["Methods"] = methods
data["Content"] = content data["Content"] = content
data["Title"] = "Routers" data["Title"] = "Routers"
execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
@ -157,8 +140,8 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
resultList := new([][]string) resultList := new([][]string)
for _, f := range bf { for _, f := range bf {
var result = []string{ var result = []string{
fmt.Sprintf("%s", f.pattern), f.pattern,
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)), utils.GetFuncName(f.filterFunc),
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} }
@ -200,6 +183,28 @@ func list(root string, p interface{}, m map[string]interface{}) {
} }
} }
// PrintTree prints all registered routers.
func PrintTree() map[string]interface{} {
var (
content = map[string]interface{}{}
methods = []string{}
methodsData = make(map[string]interface{})
)
for method, t := range BeeApp.Handlers.routers {
resultList := new([][]string)
printTree(resultList, t)
methods = append(methods, method)
methodsData[method] = resultList
}
content["Data"] = methodsData
content["Methods"] = methods
return content
}
func printTree(resultList *[][]string, t *Tree) { func printTree(resultList *[][]string, t *Tree) {
for _, tr := range t.fixrouters { for _, tr := range t.fixrouters {
printTree(resultList, tr) printTree(resultList, tr)
@ -208,12 +213,12 @@ func printTree(resultList *[][]string, t *Tree) {
printTree(resultList, t.wildcard) printTree(resultList, t.wildcard)
} }
for _, l := range t.leaves { for _, l := range t.leaves {
if v, ok := l.runObject.(*controllerInfo); ok { if v, ok := l.runObject.(*ControllerInfo); ok {
if v.routerType == routerTypeBeego { if v.routerType == routerTypeBeego {
var result = []string{ var result = []string{
v.pattern, v.pattern,
fmt.Sprintf("%s", v.methods), fmt.Sprintf("%s", v.methods),
fmt.Sprintf("%s", v.controllerType), v.controllerType.String(),
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} else if v.routerType == routerTypeRESTFul { } else if v.routerType == routerTypeRESTFul {
@ -276,8 +281,8 @@ func profIndex(rw http.ResponseWriter, r *http.Request) {
// it's in "/healthcheck" pattern in admin module. // it's in "/healthcheck" pattern in admin module.
func healthcheck(rw http.ResponseWriter, req *http.Request) { func healthcheck(rw http.ResponseWriter, req *http.Request) {
var ( var (
result []string
data = make(map[interface{}]interface{}) data = make(map[interface{}]interface{})
result = []string{}
resultList = new([][]string) resultList = new([][]string)
content = map[string]interface{}{ content = map[string]interface{}{
"Fields": []string{"Name", "Message", "Status"}, "Fields": []string{"Name", "Message", "Status"},
@ -287,21 +292,20 @@ func healthcheck(rw http.ResponseWriter, req *http.Request) {
for name, h := range toolbox.AdminCheckList { for name, h := range toolbox.AdminCheckList {
if err := h.Check(); err != nil { if err := h.Check(); err != nil {
result = []string{ result = []string{
fmt.Sprintf("error"), "error",
fmt.Sprintf("%s", name), name,
fmt.Sprintf("%s", err.Error()), err.Error(),
} }
} else { } else {
result = []string{ result = []string{
fmt.Sprintf("success"), "success",
fmt.Sprintf("%s", name), name,
fmt.Sprintf("OK"), "OK",
} }
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)
} }
content["Data"] = resultList content["Data"] = resultList
data["Content"] = content data["Content"] = content
data["Title"] = "Health Check" data["Title"] = "Health Check"
@ -330,7 +334,6 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
// List Tasks // List Tasks
content := make(map[string]interface{}) content := make(map[string]interface{})
resultList := new([][]string) resultList := new([][]string)
var result = []string{}
var fields = []string{ var fields = []string{
"Task Name", "Task Name",
"Task Spec", "Task Spec",
@ -339,10 +342,10 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
"", "",
} }
for tname, tk := range toolbox.AdminTaskList { for tname, tk := range toolbox.AdminTaskList {
result = []string{ result := []string{
tname, tname,
fmt.Sprintf("%s", tk.GetSpec()), tk.GetSpec(),
fmt.Sprintf("%s", tk.GetStatus()), tk.GetStatus(),
tk.GetPrev().String(), tk.GetPrev().String(),
} }
*resultList = append(*resultList, result) *resultList = append(*resultList, result)

6
app.go
View File

@ -348,9 +348,9 @@ func Any(rootpath string, f FilterFunc) *App {
// Handler used to register a Handler router // Handler used to register a Handler router
// usage: // usage:
// beego.Handler("/api", func(ctx *context.Context){ // beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
// ctx.Output.Body("hello world") // fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path))
// }) // }))
func Handler(rootpath string, h http.Handler, options ...interface{}) *App { func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
BeeApp.Handlers.Handler(rootpath, h, options...) BeeApp.Handlers.Handler(rootpath, h, options...)
return BeeApp return BeeApp

View File

@ -23,7 +23,7 @@ import (
const ( const (
// VERSION represent beego web framework version. // VERSION represent beego web framework version.
VERSION = "1.7.2" VERSION = "1.9.0"
// DEV is for develop // DEV is for develop
DEV = "dev" DEV = "dev"
@ -40,9 +40,9 @@ var (
// AddAPPStartHook is used to register the hookfunc // AddAPPStartHook is used to register the hookfunc
// The hookfuncs will run in beego.Run() // The hookfuncs will run in beego.Run()
// such as sessionInit, middleware start, buildtemplate, admin start // such as initiating session , starting middleware , building template, starting admin control and so on.
func AddAPPStartHook(hf hookfunc) { func AddAPPStartHook(hf ...hookfunc) {
hooks = append(hooks, hf) hooks = append(hooks, hf...)
} }
// Run beego application. // Run beego application.
@ -69,12 +69,14 @@ func Run(params ...string) {
func initBeforeHTTPRun() { func initBeforeHTTPRun() {
//init hooks //init hooks
AddAPPStartHook(registerMime) AddAPPStartHook(
AddAPPStartHook(registerDefaultErrorHandler) registerMime,
AddAPPStartHook(registerSession) registerDefaultErrorHandler,
AddAPPStartHook(registerTemplate) registerSession,
AddAPPStartHook(registerAdmin) registerTemplate,
AddAPPStartHook(registerGzip) registerAdmin,
registerGzip,
)
for _, hk := range hooks { for _, hk := range hooks {
if err := hk(); err != nil { if err := hk(); err != nil {

2
cache/conv.go vendored
View File

@ -28,7 +28,7 @@ func GetString(v interface{}) string {
return string(result) return string(result)
default: default:
if v != nil { if v != nil {
return fmt.Sprintf("%v", result) return fmt.Sprint(result)
} }
} }
return "" return ""

6
cache/conv_test.go vendored
View File

@ -118,14 +118,14 @@ func TestGetFloat64(t *testing.T) {
func TestGetBool(t *testing.T) { func TestGetBool(t *testing.T) {
var t1 = true var t1 = true
if true != GetBool(t1) { if !GetBool(t1) {
t.Error("get bool from bool error") t.Error("get bool from bool error")
} }
var t2 = "true" var t2 = "true"
if true != GetBool(t2) { if !GetBool(t2) {
t.Error("get bool from string error") t.Error("get bool from string error")
} }
if false != GetBool(nil) { if GetBool(nil) {
t.Error("get bool from nil error") t.Error("get bool from nil error")
} }
} }

25
cache/file.go vendored
View File

@ -22,6 +22,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
@ -222,33 +223,13 @@ func exists(path string) (bool, error) {
// FileGetContents Get bytes to file. // FileGetContents Get bytes to file.
// if non-exist, create this file. // if non-exist, create this file.
func FileGetContents(filename string) (data []byte, e error) { func FileGetContents(filename string) (data []byte, e error) {
f, e := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm) return ioutil.ReadFile(filename)
if e != nil {
return
}
defer f.Close()
stat, e := f.Stat()
if e != nil {
return
}
data = make([]byte, stat.Size())
result, e := f.Read(data)
if e != nil || int64(result) != stat.Size() {
return nil, e
}
return
} }
// FilePutContents Put bytes to file. // FilePutContents Put bytes to file.
// if non-exist, create this file. // if non-exist, create this file.
func FilePutContents(filename string, content []byte) error { func FilePutContents(filename string, content []byte) error {
fp, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm) return ioutil.WriteFile(filename, content, os.ModePerm)
if err != nil {
return err
}
defer fp.Close()
_, err = fp.Write(content)
return err
} }
// GobEncode Gob encodes file cache item. // GobEncode Gob encodes file cache item.

View File

@ -146,10 +146,7 @@ func (rc *Cache) IsExist(key string) bool {
} }
} }
_, err := rc.conn.Get(key) _, err := rc.conn.Get(key)
if err != nil { return !(err != nil)
return false
}
return true
} }
// ClearAll clear all cached in memcache. // ClearAll clear all cached in memcache.

31
cache/memory.go vendored
View File

@ -217,26 +217,31 @@ func (bc *MemoryCache) vaccuum() {
if bc.items == nil { if bc.items == nil {
return return
} }
for name := range bc.items { if keys := bc.expiredKeys(); len(keys) != 0 {
bc.itemExpired(name) bc.clearItems(keys)
} }
} }
} }
// itemExpired returns true if an item is expired. // expiredKeys returns key list which are expired.
func (bc *MemoryCache) itemExpired(name string) bool { func (bc *MemoryCache) expiredKeys() (keys []string) {
bc.RLock()
defer bc.RUnlock()
for key, itm := range bc.items {
if itm.isExpire() {
keys = append(keys, key)
}
}
return
}
// clearItems removes all the items which key in keys.
func (bc *MemoryCache) clearItems(keys []string) {
bc.Lock() bc.Lock()
defer bc.Unlock() defer bc.Unlock()
for _, key := range keys {
itm, ok := bc.items[name] delete(bc.items, key)
if !ok {
return true
} }
if itm.isExpire() {
delete(bc.items, name)
return true
}
return false
} }
func init() { func init() {

View File

@ -137,7 +137,7 @@ func (rc *Cache) IsExist(key string) bool {
if err != nil { if err != nil {
return false return false
} }
if v == false { if !v {
if _, err = rc.do("HDEL", rc.key, key); err != nil { if _, err = rc.do("HDEL", rc.key, key); err != nil {
return false return false
} }

21
cache/ssdb/ssdb.go vendored
View File

@ -53,7 +53,7 @@ func (rc *Cache) GetMulti(keys []string) []interface{} {
resSize := len(res) resSize := len(res)
if err == nil { if err == nil {
for i := 1; i < resSize; i += 2 { for i := 1; i < resSize; i += 2 {
values = append(values, string(res[i+1])) values = append(values, res[i+1])
} }
return values return values
} }
@ -71,10 +71,7 @@ func (rc *Cache) DelMulti(keys []string) error {
} }
} }
_, err := rc.conn.Do("multi_del", keys) _, err := rc.conn.Do("multi_del", keys)
if err != nil { return err
return err
}
return nil
} }
// Put put value to memcache. only support string. // Put put value to memcache. only support string.
@ -113,10 +110,7 @@ func (rc *Cache) Delete(key string) error {
} }
} }
_, err := rc.conn.Del(key) _, err := rc.conn.Del(key)
if err != nil { return err
return err
}
return nil
} }
// Incr increase counter. // Incr increase counter.
@ -152,7 +146,7 @@ func (rc *Cache) IsExist(key string) bool {
if err != nil { if err != nil {
return false return false
} }
if resp[1] == "1" { if len(resp) == 2 && resp[1] == "1" {
return true return true
} }
return false return false
@ -175,7 +169,7 @@ func (rc *Cache) ClearAll() error {
} }
keys := []string{} keys := []string{}
for i := 1; i < size; i += 2 { for i := 1; i < size; i += 2 {
keys = append(keys, string(resp[i])) keys = append(keys, resp[i])
} }
_, e := rc.conn.Do("multi_del", keys) _, e := rc.conn.Do("multi_del", keys)
if e != nil { if e != nil {
@ -229,10 +223,7 @@ func (rc *Cache) connectInit() error {
} }
var err error var err error
rc.conn, err = ssdb.Connect(host, port) rc.conn, err = ssdb.Connect(host, port)
if err != nil { return err
return err
}
return nil
} }
func init() { func init() {

View File

@ -41,6 +41,7 @@ type Config struct {
EnableGzip bool EnableGzip bool
MaxMemory int64 MaxMemory int64
EnableErrorsShow bool EnableErrorsShow bool
EnableErrorsRender bool
Listen Listen Listen Listen
WebConfig WebConfig WebConfig WebConfig
Log LogConfig Log LogConfig
@ -144,9 +145,6 @@ func init() {
if err = parseConfig(appConfigPath); err != nil { if err = parseConfig(appConfigPath); err != nil {
panic(err) panic(err)
} }
if err = os.Chdir(AppPath); err != nil {
panic(err)
}
} }
func recoverPanic(ctx *context.Context) { func recoverPanic(ctx *context.Context) {
@ -174,7 +172,7 @@ func recoverPanic(ctx *context.Context) {
logs.Critical(fmt.Sprintf("%s:%d", file, line)) logs.Critical(fmt.Sprintf("%s:%d", file, line))
stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
} }
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV && BConfig.EnableErrorsRender {
showErr(err, ctx, stack) showErr(err, ctx, stack)
} }
} }
@ -192,6 +190,7 @@ func newBConfig() *Config {
EnableGzip: false, EnableGzip: false,
MaxMemory: 1 << 26, //64MB MaxMemory: 1 << 26, //64MB
EnableErrorsShow: true, EnableErrorsShow: true,
EnableErrorsRender: true,
Listen: Listen{ Listen: Listen{
Graceful: false, Graceful: false,
ServerTimeOut: 0, ServerTimeOut: 0,
@ -257,6 +256,9 @@ func parseConfig(appConfigPath string) (err error) {
} }
func assignConfig(ac config.Configer) error { func assignConfig(ac config.Configer) error {
for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} {
assignSingleConfig(i, ac)
}
// set the run mode first // set the run mode first
if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
BConfig.RunMode = envRunMode BConfig.RunMode = envRunMode
@ -264,10 +266,6 @@ func assignConfig(ac config.Configer) error {
BConfig.RunMode = runMode BConfig.RunMode = runMode
} }
for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} {
assignSingleConfig(i, ac)
}
if sd := ac.String("StaticDir"); sd != "" { if sd := ac.String("StaticDir"); sd != "" {
BConfig.WebConfig.StaticDir = map[string]string{} BConfig.WebConfig.StaticDir = map[string]string{}
sds := strings.Fields(sd) sds := strings.Fields(sd)
@ -347,7 +345,7 @@ func assignSingleConfig(p interface{}, ac config.Configer) {
case reflect.String: case reflect.String:
pf.SetString(ac.DefaultString(name, pf.String())) pf.SetString(ac.DefaultString(name, pf.String()))
case reflect.Int, reflect.Int64: case reflect.Int, reflect.Int64:
pf.SetInt(int64(ac.DefaultInt64(name, pf.Int()))) pf.SetInt(ac.DefaultInt64(name, pf.Int()))
case reflect.Bool: case reflect.Bool:
pf.SetBool(ac.DefaultBool(name, pf.Bool())) pf.SetBool(ac.DefaultBool(name, pf.Bool()))
case reflect.Struct: case reflect.Struct:

View File

@ -189,16 +189,16 @@ func ParseBool(val interface{}) (value bool, err error) {
return false, nil return false, nil
} }
case int8, int32, int64: case int8, int32, int64:
strV := fmt.Sprintf("%s", v) strV := fmt.Sprintf("%d", v)
if strV == "1" { if strV == "1" {
return true, nil return true, nil
} else if strV == "0" { } else if strV == "0" {
return false, nil return false, nil
} }
case float64: case float64:
if v == 1 { if v == 1.0 {
return true, nil return true, nil
} else if v == 0 { } else if v == 0.0 {
return false, nil return false, nil
} }
} }

87
config/env/env.go vendored Normal file
View File

@ -0,0 +1,87 @@
// Copyright 2014 beego Author. All Rights Reserved.
// Copyright 2017 Faissal Elamraoui. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package env is used to parse environment.
package env
import (
"fmt"
"os"
"strings"
"github.com/astaxie/beego/utils"
)
var env *utils.BeeMap
func init() {
env = utils.NewBeeMap()
for _, e := range os.Environ() {
splits := strings.Split(e, "=")
env.Set(splits[0], os.Getenv(splits[0]))
}
}
// Get returns a value by key.
// If the key does not exist, the default value will be returned.
func Get(key string, defVal string) string {
if val := env.Get(key); val != nil {
return val.(string)
}
return defVal
}
// MustGet returns a value by key.
// If the key does not exist, it will return an error.
func MustGet(key string) (string, error) {
if val := env.Get(key); val != nil {
return val.(string), nil
}
return "", fmt.Errorf("no env variable with %s", key)
}
// Set sets a value in the ENV copy.
// This does not affect the child process environment.
func Set(key string, value string) {
env.Set(key, value)
}
// MustSet sets a value in the ENV copy and the child process environment.
// It returns an error in case the set operation failed.
func MustSet(key string, value string) error {
err := os.Setenv(key, value)
if err != nil {
return err
}
env.Set(key, value)
return nil
}
// GetAll returns all keys/values in the current child process environment.
func GetAll() map[string]string {
items := env.Items()
envs := make(map[string]string, env.Count())
for key, val := range items {
switch key := key.(type) {
case string:
switch val := val.(type) {
case string:
envs[key] = val
}
}
}
return envs
}

75
config/env/env_test.go vendored Normal file
View File

@ -0,0 +1,75 @@
// Copyright 2014 beego Author. All Rights Reserved.
// Copyright 2017 Faissal Elamraoui. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package env
import (
"os"
"testing"
)
func TestEnvGet(t *testing.T) {
gopath := Get("GOPATH", "")
if gopath != os.Getenv("GOPATH") {
t.Error("expected GOPATH not empty.")
}
noExistVar := Get("NOEXISTVAR", "foo")
if noExistVar != "foo" {
t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar)
}
}
func TestEnvMustGet(t *testing.T) {
gopath, err := MustGet("GOPATH")
if err != nil {
t.Error(err)
}
if gopath != os.Getenv("GOPATH") {
t.Errorf("expected GOPATH to be the same, got %s.", gopath)
}
_, err = MustGet("NOEXISTVAR")
if err == nil {
t.Error("expected error to be non-nil")
}
}
func TestEnvSet(t *testing.T) {
Set("MYVAR", "foo")
myVar := Get("MYVAR", "bar")
if myVar != "foo" {
t.Errorf("expected MYVAR to equal foo, got %s.", myVar)
}
}
func TestEnvMustSet(t *testing.T) {
err := MustSet("FOO", "bar")
if err != nil {
t.Error(err)
}
fooVar := os.Getenv("FOO")
if fooVar != "bar" {
t.Errorf("expected FOO variable to equal bar, got %s.", fooVar)
}
}
func TestEnvGetAll(t *testing.T) {
envMap := GetAll()
if len(envMap) == 0 {
t.Error("expected environment not empty.")
}
}

View File

@ -18,16 +18,14 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "os/user"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
) )
var ( var (
@ -52,24 +50,26 @@ func (ini *IniConfig) Parse(name string) (Configer, error) {
} }
func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
file, err := os.Open(name) data, err := ioutil.ReadFile(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ini.parseData(filepath.Dir(name), data)
}
func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) {
cfg := &IniConfigContainer{ cfg := &IniConfigContainer{
file.Name(), data: make(map[string]map[string]string),
make(map[string]map[string]string), sectionComment: make(map[string]string),
make(map[string]string), keyComment: make(map[string]string),
make(map[string]string), RWMutex: sync.RWMutex{},
sync.RWMutex{},
} }
cfg.Lock() cfg.Lock()
defer cfg.Unlock() defer cfg.Unlock()
defer file.Close()
var comment bytes.Buffer var comment bytes.Buffer
buf := bufio.NewReader(file) buf := bufio.NewReader(bytes.NewBuffer(data))
// check the BOM // check the BOM
head, err := buf.Peek(3) head, err := buf.Peek(3)
if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 { if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 {
@ -130,16 +130,20 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
// handle include "other.conf" // handle include "other.conf"
if len(keyValue) == 1 && strings.HasPrefix(key, "include") { if len(keyValue) == 1 && strings.HasPrefix(key, "include") {
includefiles := strings.Fields(key) includefiles := strings.Fields(key)
if includefiles[0] == "include" && len(includefiles) == 2 { if includefiles[0] == "include" && len(includefiles) == 2 {
otherfile := strings.Trim(includefiles[1], "\"") otherfile := strings.Trim(includefiles[1], "\"")
if !filepath.IsAbs(otherfile) { if !filepath.IsAbs(otherfile) {
otherfile = filepath.Join(filepath.Dir(name), otherfile) otherfile = filepath.Join(dir, otherfile)
} }
i, err := ini.parseFile(otherfile) i, err := ini.parseFile(otherfile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for sec, dt := range i.data { for sec, dt := range i.data {
if _, ok := cfg.data[sec]; !ok { if _, ok := cfg.data[sec]; !ok {
cfg.data[sec] = make(map[string]string) cfg.data[sec] = make(map[string]string)
@ -148,12 +152,15 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
cfg.data[sec][k] = v cfg.data[sec][k] = v
} }
} }
for sec, comm := range i.sectionComment { for sec, comm := range i.sectionComment {
cfg.sectionComment[sec] = comm cfg.sectionComment[sec] = comm
} }
for k, comm := range i.keyComment { for k, comm := range i.keyComment {
cfg.keyComment[k] = comm cfg.keyComment[k] = comm
} }
continue continue
} }
} }
@ -177,20 +184,25 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
} }
// ParseData parse ini the data // ParseData parse ini the data
// When include other.conf,other.conf is either absolute directory
// or under beego in default temporary directory(/tmp/beego[-username]).
func (ini *IniConfig) ParseData(data []byte) (Configer, error) { func (ini *IniConfig) ParseData(data []byte) (Configer, error) {
// Save memory data to temporary file dir := "beego"
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) currentUser, err := user.Current()
os.MkdirAll(path.Dir(tmpName), os.ModePerm) if err == nil {
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil { dir = "beego-" + currentUser.Username
}
dir = filepath.Join(os.TempDir(), dir)
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
return nil, err return nil, err
} }
return ini.Parse(tmpName)
return ini.parseData(dir, data)
} }
// IniConfigContainer A Config represents the ini configuration. // IniConfigContainer A Config represents the ini configuration.
// When set and get value, support key as section:name type. // When set and get value, support key as section:name type.
type IniConfigContainer struct { type IniConfigContainer struct {
filename string
data map[string]map[string]string // section=> key:val data map[string]map[string]string // section=> key:val
sectionComment map[string]string // section : comment sectionComment map[string]string // section : comment
keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment.
@ -297,7 +309,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro
if v, ok := c.data[section]; ok { if v, ok := c.data[section]; ok {
return v, nil return v, nil
} }
return nil, errors.New("not exist setction") return nil, errors.New("not exist section")
} }
// SaveConfigFile save the config into file. // SaveConfigFile save the config into file.
@ -313,7 +325,10 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
// Get section or key comments. Fixed #1607 // Get section or key comments. Fixed #1607
getCommentStr := func(section, key string) string { getCommentStr := func(section, key string) string {
comment, ok := "", false var (
comment string
ok bool
)
if len(key) == 0 { if len(key) == 0 {
comment, ok = c.sectionComment[section] comment, ok = c.sectionComment[section]
} else { } else {
@ -393,11 +408,8 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
} }
} }
} }
_, err = buf.WriteTo(f)
if _, err = buf.WriteTo(f); err != nil { return err
return err
}
return nil
} }
// Set writes a new value for key. // Set writes a new value for key.
@ -412,7 +424,7 @@ func (c *IniConfigContainer) Set(key, value string) error {
var ( var (
section, k string section, k string
sectionKey = strings.Split(key, "::") sectionKey = strings.Split(strings.ToLower(key), "::")
) )
if len(sectionKey) >= 2 { if len(sectionKey) >= 2 {

View File

@ -181,7 +181,7 @@ name=mysql
cfgData := string(data) cfgData := string(data)
datas := strings.Split(saveResult, "\n") datas := strings.Split(saveResult, "\n")
for _, line := range datas { for _, line := range datas {
if strings.Contains(cfgData, line+"\n") == false { if !strings.Contains(cfgData, line+"\n") {
t.Fatalf("different after save ini config file. need contains %q", line) t.Fatalf("different after save ini config file. need contains %q", line)
} }
} }

View File

@ -35,11 +35,9 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"path"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
"github.com/beego/x2j" "github.com/beego/x2j"
@ -52,36 +50,26 @@ type Config struct{}
// Parse returns a ConfigContainer with parsed xml config map. // Parse returns a ConfigContainer with parsed xml config map.
func (xc *Config) Parse(filename string) (config.Configer, error) { func (xc *Config) Parse(filename string) (config.Configer, error) {
file, err := os.Open(filename) context, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer file.Close()
return xc.ParseData(context)
}
// ParseData xml data
func (xc *Config) ParseData(data []byte) (config.Configer, error) {
x := &ConfigContainer{data: make(map[string]interface{})} x := &ConfigContainer{data: make(map[string]interface{})}
content, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}
d, err := x2j.DocToMap(string(content)) d, err := x2j.DocToMap(string(data))
if err != nil { if err != nil {
return nil, err return nil, err
} }
x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{}))
return x, nil
}
// ParseData xml data return x, nil
func (xc *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
return nil, err
}
return xc.Parse(tmpName)
} }
// ConfigContainer A Config represents the xml configuration. // ConfigContainer A Config represents the xml configuration.

View File

@ -37,10 +37,8 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"path"
"strings" "strings"
"sync" "sync"
"time"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
"github.com/beego/goyaml2" "github.com/beego/goyaml2"
@ -63,26 +61,30 @@ func (yaml *Config) Parse(filename string) (y config.Configer, err error) {
// ParseData parse yaml data // ParseData parse yaml data
func (yaml *Config) ParseData(data []byte) (config.Configer, error) { func (yaml *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file cnf, err := parseYML(data)
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) if err != nil {
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
return nil, err return nil, err
} }
return yaml.Parse(tmpName)
return &ConfigContainer{
data: cnf,
}, nil
} }
// ReadYmlReader Read yaml file to map. // ReadYmlReader Read yaml file to map.
// if json like, use json package, unless goyaml2 package. // if json like, use json package, unless goyaml2 package.
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
f, err := os.Open(path) buf, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
return return
} }
defer f.Close()
buf, err := ioutil.ReadAll(f) return parseYML(buf)
if err != nil || len(buf) < 3 { }
// parseYML parse yaml formatted []byte to map.
func parseYML(buf []byte) (cnf map[string]interface{}, err error) {
if len(buf) < 3 {
return return
} }
@ -250,7 +252,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error)
if v, ok := c.data[section]; ok { if v, ok := c.data[section]; ok {
return v.(map[string]string), nil return v.(map[string]string), nil
} }
return nil, errors.New("not exist setction") return nil, errors.New("not exist section")
} }
// SaveConfigFile save the config into file // SaveConfigFile save the config into file

View File

@ -39,6 +39,7 @@ var (
getMethodOnly bool getMethodOnly bool
) )
// InitGzip init the gzipcompress
func InitGzip(minLength, compressLevel int, methods []string) { func InitGzip(minLength, compressLevel int, methods []string) {
if minLength >= 0 { if minLength >= 0 {
gzipMinLength = minLength gzipMinLength = minLength

View File

@ -171,6 +171,22 @@ func (ctx *Context) CheckXSRFCookie() bool {
return true return true
} }
// RenderMethodResult renders the return value of a controller method to the output
func (ctx *Context) RenderMethodResult(result interface{}) {
if result != nil {
renderer, ok := result.(Renderer)
if !ok {
err, ok := result.(error)
if ok {
renderer = errorRenderer(err)
} else {
renderer = jsonRenderer(result)
}
}
renderer.Render(ctx)
}
}
//Response is a wrapper for the http.ResponseWriter //Response is a wrapper for the http.ResponseWriter
//started set to true if response was written to then don't execute other handler //started set to true if response was written to then don't execute other handler
type Response struct { type Response struct {

View File

@ -16,9 +16,11 @@ package context
import ( import (
"bytes" "bytes"
"compress/gzip"
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"regexp" "regexp"
@ -349,11 +351,22 @@ func (input *BeegoInput) CopyBody(MaxMemory int64) []byte {
if input.Context.Request.Body == nil { if input.Context.Request.Body == nil {
return []byte{} return []byte{}
} }
var requestbody []byte
safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory} safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory}
requestbody, _ := ioutil.ReadAll(safe) if input.Header("Content-Encoding") == "gzip" {
reader, err := gzip.NewReader(safe)
if err != nil {
return nil
}
requestbody, _ = ioutil.ReadAll(reader)
} else {
requestbody, _ = ioutil.ReadAll(safe)
}
input.Context.Request.Body.Close() input.Context.Request.Body.Close()
bf := bytes.NewBuffer(requestbody) bf := bytes.NewBuffer(requestbody)
input.Context.Request.Body = ioutil.NopCloser(bf) input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, ioutil.NopCloser(bf), MaxMemory)
input.RequestBody = requestbody input.RequestBody = requestbody
return requestbody return requestbody
} }
@ -413,7 +426,13 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error {
if !value.CanSet() { if !value.CanSet() {
return errors.New("beego: non-settable variable passed to Bind: " + key) return errors.New("beego: non-settable variable passed to Bind: " + key)
} }
rv := input.bind(key, value.Type()) typ := value.Type()
// Get real type if dest define with interface{}.
// e.g var dest interface{} dest=1.0
if value.Kind() == reflect.Interface {
typ = value.Elem().Type()
}
rv := input.bind(key, typ)
if !rv.IsValid() { if !rv.IsValid() {
return errors.New("beego: reflect value is empty") return errors.New("beego: reflect value is empty")
} }
@ -422,6 +441,9 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error {
} }
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
if input.Context.Request.Form == nil {
input.Context.Request.ParseForm()
}
rv := reflect.Zero(typ) rv := reflect.Zero(typ)
switch typ.Kind() { switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:

View File

@ -15,81 +15,97 @@
package context package context
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"testing" "testing"
) )
func TestParse(t *testing.T) { func TestBind(t *testing.T) {
r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) type testItem struct {
beegoInput := NewInput() field string
beegoInput.Context = NewContext() empty interface{}
beegoInput.Context.Reset(httptest.NewRecorder(), r) want interface{}
beegoInput.ParseFormOrMulitForm(1 << 20) }
type Human struct {
ID int
Nick string
Pwd string
Ms bool
}
var id int cases := []struct {
err := beegoInput.Bind(&id, "id") request string
if id != 123 || err != nil { valueGp []testItem
t.Fatal("id should has int value") }{
} {"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}},
fmt.Println(id)
var isok bool {"/?p=", []testItem{{"p", "", ""}}},
err = beegoInput.Bind(&isok, "isok") {"/?p=str", []testItem{{"p", "", "str"}}},
if !isok || err != nil {
t.Fatal("isok should be true")
}
fmt.Println(isok)
var float float64 {"/?p=123", []testItem{{"p", 0, 123}}},
err = beegoInput.Bind(&float, "ft") {"/?p=123", []testItem{{"p", uint(0), uint(123)}}},
if float != 1.2 || err != nil {
t.Fatal("float should be equal to 1.2")
}
fmt.Println(float)
ol := make([]int, 0, 2) {"/?p=1.0", []testItem{{"p", 0.0, 1.0}}},
err = beegoInput.Bind(&ol, "ol") {"/?p=1", []testItem{{"p", false, true}}},
if len(ol) != 2 || err != nil || ol[0] != 1 || ol[1] != 2 {
t.Fatal("ol should has two elements")
}
fmt.Println(ol)
ul := make([]string, 0, 2) {"/?p=true", []testItem{{"p", false, true}}},
err = beegoInput.Bind(&ul, "ul") {"/?p=ON", []testItem{{"p", false, true}}},
if len(ul) != 2 || err != nil || ul[0] != "str" || ul[1] != "array" { {"/?p=on", []testItem{{"p", false, true}}},
t.Fatal("ul should has two elements") {"/?p=1", []testItem{{"p", false, true}}},
} {"/?p=2", []testItem{{"p", false, false}}},
fmt.Println(ul) {"/?p=false", []testItem{{"p", false, false}}},
type User struct { {"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}},
Name string {"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}},
}
user := User{}
err = beegoInput.Bind(&user, "user")
if err != nil || user.Name != "astaxie" {
t.Fatal("user should has name")
}
fmt.Println(user)
}
func TestParse2(t *testing.T) { {"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}},
r, _ := http.NewRequest("GET", "/?user[0][Username]=Raph&user[1].Username=Leo&user[0].Password=123456&user[1][Password]=654321", nil) {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}},
beegoInput := NewInput() {"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}},
beegoInput.Context = NewContext() {"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}},
beegoInput.Context.Reset(httptest.NewRecorder(), r)
beegoInput.ParseFormOrMulitForm(1 << 20) {"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}},
type User struct { {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}},
Username string
Password string {"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}},
{"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}},
{"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}},
{"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02",
[]testItem{{"human", []Human{}, []Human{
{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"},
{ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"},
}}}},
{
"/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie",
[]testItem{
{"id", 0, 123},
{"isok", false, true},
{"ft", 0.0, 1.2},
{"ol", []int{}, []int{1, 2}},
{"ul", []string{}, []string{"str", "array"}},
{"human", Human{}, Human{Nick: "astaxie"}},
},
},
} }
var users []User for _, c := range cases {
err := beegoInput.Bind(&users, "user") r, _ := http.NewRequest("GET", c.request, nil)
fmt.Println(users) beegoInput := NewInput()
if err != nil || users[0].Username != "Raph" || users[0].Password != "123456" || users[1].Username != "Leo" || users[1].Password != "654321" { beegoInput.Context = NewContext()
t.Fatal("users info wrong") beegoInput.Context.Reset(httptest.NewRecorder(), r)
for _, item := range c.valueGp {
got := item.empty
err := beegoInput.Bind(&got, item.field)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, item.want) {
t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got)
}
}
} }
} }

View File

@ -67,6 +67,7 @@ func (output *BeegoOutput) Body(content []byte) error {
} }
if b, n, _ := WriteBody(encoding, buf, content); b { if b, n, _ := WriteBody(encoding, buf, content); b {
output.Header("Content-Encoding", n) output.Header("Content-Encoding", n)
output.Header("Content-Length", strconv.Itoa(buf.Len()))
} else { } else {
output.Header("Content-Length", strconv.Itoa(len(content))) output.Header("Content-Length", strconv.Itoa(len(content)))
} }
@ -167,6 +168,19 @@ func sanitizeValue(v string) string {
return cookieValueSanitizer.Replace(v) return cookieValueSanitizer.Replace(v)
} }
func jsonRenderer(value interface{}) Renderer {
return rendererFunc(func(ctx *Context) {
ctx.Output.JSON(value, false, false)
})
}
func errorRenderer(err error) Renderer {
return rendererFunc(func(ctx *Context) {
ctx.Output.SetStatus(500)
ctx.Output.Body([]byte(err.Error()))
})
}
// JSON writes json to response body. // JSON writes json to response body.
// if coding is true, it converts utf-8 to \u0000 type. // if coding is true, it converts utf-8 to \u0000 type.
func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error { func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error {
@ -329,17 +343,17 @@ func (output *BeegoOutput) IsServerError() bool {
} }
func stringsToJSON(str string) string { func stringsToJSON(str string) string {
rs := []rune(str) var jsons bytes.Buffer
jsons := "" for _, r := range str {
for _, r := range rs {
rint := int(r) rint := int(r)
if rint < 128 { if rint < 128 {
jsons += string(r) jsons.WriteRune(r)
} else { } else {
jsons += "\\u" + strconv.FormatInt(int64(rint), 16) // json jsons.WriteString("\\u")
jsons.WriteString(strconv.FormatInt(int64(rint), 16))
} }
} }
return jsons return jsons.String()
} }
// Session sets session item value with given key. // Session sets session item value with given key.

78
context/param/conv.go Normal file
View File

@ -0,0 +1,78 @@
package param
import (
"fmt"
"reflect"
beecontext "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
)
// ConvertParams converts http method params to values that will be passed to the method controller as arguments
func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) {
result = make([]reflect.Value, 0, len(methodParams))
for i := 0; i < len(methodParams); i++ {
reflectValue := convertParam(methodParams[i], methodType.In(i), ctx)
result = append(result, reflectValue)
}
return
}
func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) {
paramValue := getParamValue(param, ctx)
if paramValue == "" {
if param.required {
ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name))
} else {
paramValue = param.defaultValue
}
}
reflectValue, err := parseValue(param, paramValue, paramType)
if err != nil {
logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err))
ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType))
}
return reflectValue
}
func getParamValue(param *MethodParam, ctx *beecontext.Context) string {
switch param.in {
case body:
return string(ctx.Input.RequestBody)
case header:
return ctx.Input.Header(param.name)
case path:
return ctx.Input.Query(":" + param.name)
default:
return ctx.Input.Query(param.name)
}
}
func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) {
if paramValue == "" {
return reflect.Zero(paramType), nil
}
parser := getParser(param, paramType)
value, err := parser.parse(paramValue, paramType)
if err != nil {
return result, err
}
return safeConvert(reflect.ValueOf(value), paramType)
}
func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok = r.(error)
if !ok {
err = fmt.Errorf("%v", r)
}
}
}()
result = value.Convert(t)
return
}

View File

@ -0,0 +1,69 @@
package param
import (
"fmt"
"strings"
)
//MethodParam keeps param information to be auto passed to controller methods
type MethodParam struct {
name string
in paramType
required bool
defaultValue string
}
type paramType byte
const (
param paramType = iota
path
body
header
)
//New creates a new MethodParam with name and specific options
func New(name string, opts ...MethodParamOption) *MethodParam {
return newParam(name, nil, opts)
}
func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) {
param = &MethodParam{name: name}
for _, option := range opts {
option(param)
}
return
}
//Make creates an array of MethodParmas or an empty array
func Make(list ...*MethodParam) []*MethodParam {
if len(list) > 0 {
return list
}
return nil
}
func (mp *MethodParam) String() string {
options := []string{}
result := "param.New(\"" + mp.name + "\""
if mp.required {
options = append(options, "param.IsRequired")
}
switch mp.in {
case path:
options = append(options, "param.InPath")
case body:
options = append(options, "param.InBody")
case header:
options = append(options, "param.InHeader")
}
if mp.defaultValue != "" {
options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue))
}
if len(options) > 0 {
result += ", "
}
result += strings.Join(options, ", ")
result += ")"
return result
}

37
context/param/options.go Normal file
View File

@ -0,0 +1,37 @@
package param
import (
"fmt"
)
// MethodParamOption defines a func which apply options on a MethodParam
type MethodParamOption func(*MethodParam)
// IsRequired indicates that this param is required and can not be omitted from the http request
var IsRequired MethodParamOption = func(p *MethodParam) {
p.required = true
}
// InHeader indicates that this param is passed via an http header
var InHeader MethodParamOption = func(p *MethodParam) {
p.in = header
}
// InPath indicates that this param is part of the URL path
var InPath MethodParamOption = func(p *MethodParam) {
p.in = path
}
// InBody indicates that this param is passed as an http request body
var InBody MethodParamOption = func(p *MethodParam) {
p.in = body
}
// Default provides a default value for the http param
func Default(defaultValue interface{}) MethodParamOption {
return func(p *MethodParam) {
if defaultValue != nil {
p.defaultValue = fmt.Sprint(defaultValue)
}
}
}

149
context/param/parsers.go Normal file
View File

@ -0,0 +1,149 @@
package param
import (
"encoding/json"
"reflect"
"strconv"
"strings"
"time"
)
type paramParser interface {
parse(value string, toType reflect.Type) (interface{}, error)
}
func getParser(param *MethodParam, t reflect.Type) paramParser {
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return intParser{}
case reflect.Slice:
if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string
return stringParser{}
}
if param.in == body {
return jsonParser{}
}
elemParser := getParser(param, t.Elem())
if elemParser == (jsonParser{}) {
return elemParser
}
return sliceParser(elemParser)
case reflect.Bool:
return boolParser{}
case reflect.String:
return stringParser{}
case reflect.Float32, reflect.Float64:
return floatParser{}
case reflect.Ptr:
elemParser := getParser(param, t.Elem())
if elemParser == (jsonParser{}) {
return elemParser
}
return ptrParser(elemParser)
default:
if t.PkgPath() == "time" && t.Name() == "Time" {
return timeParser{}
}
return jsonParser{}
}
}
type parserFunc func(value string, toType reflect.Type) (interface{}, error)
func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) {
return f(value, toType)
}
type boolParser struct {
}
func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) {
return strconv.ParseBool(value)
}
type stringParser struct {
}
func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) {
return value, nil
}
type intParser struct {
}
func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) {
return strconv.Atoi(value)
}
type floatParser struct {
}
func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) {
if toType.Kind() == reflect.Float32 {
res, err := strconv.ParseFloat(value, 32)
if err != nil {
return nil, err
}
return float32(res), nil
}
return strconv.ParseFloat(value, 64)
}
type timeParser struct {
}
func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) {
result, err = time.Parse(time.RFC3339, value)
if err != nil {
result, err = time.Parse("2006-01-02", value)
}
return
}
type jsonParser struct {
}
func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) {
pResult := reflect.New(toType)
v := pResult.Interface()
err := json.Unmarshal([]byte(value), v)
if err != nil {
return nil, err
}
return pResult.Elem().Interface(), nil
}
func sliceParser(elemParser paramParser) paramParser {
return parserFunc(func(value string, toType reflect.Type) (interface{}, error) {
values := strings.Split(value, ",")
result := reflect.MakeSlice(toType, 0, len(values))
elemType := toType.Elem()
for _, v := range values {
parsedValue, err := elemParser.parse(v, elemType)
if err != nil {
return nil, err
}
result = reflect.Append(result, reflect.ValueOf(parsedValue))
}
return result.Interface(), nil
})
}
func ptrParser(elemParser paramParser) paramParser {
return parserFunc(func(value string, toType reflect.Type) (interface{}, error) {
parsedValue, err := elemParser.parse(value, toType.Elem())
if err != nil {
return nil, err
}
newValPtr := reflect.New(toType.Elem())
newVal := reflect.Indirect(newValPtr)
convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem())
if err != nil {
return nil, err
}
newVal.Set(convertedVal)
return newValPtr.Interface(), nil
})
}

View File

@ -0,0 +1,84 @@
package param
import "testing"
import "reflect"
import "time"
type testDefinition struct {
strValue string
expectedValue interface{}
expectedParser paramParser
}
func Test_Parsers(t *testing.T) {
//ints
checkParser(testDefinition{"1", 1, intParser{}}, t)
checkParser(testDefinition{"-1", int64(-1), intParser{}}, t)
checkParser(testDefinition{"1", uint64(1), intParser{}}, t)
//floats
checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t)
checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t)
//strings
checkParser(testDefinition{"AB", "AB", stringParser{}}, t)
checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t)
//bools
checkParser(testDefinition{"true", true, boolParser{}}, t)
checkParser(testDefinition{"0", false, boolParser{}}, t)
//timeParser
checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t)
checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t)
//json
checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct {
X int
Y string
}{5, "Z"}, jsonParser{}}, t)
//slice in query is parsed as comma delimited
checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t)
//slice in body is parsed as json
checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body})
//pointers
var someInt = 1
checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t)
var someStruct = struct{ X int }{5}
checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t)
}
func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) {
toType := reflect.TypeOf(def.expectedValue)
var mp MethodParam
if len(methodParam) == 0 {
mp = MethodParam{}
} else {
mp = methodParam[0]
}
parser := getParser(&mp, toType)
if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) {
t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name())
return
}
result, err := parser.parse(def.strValue, toType)
if err != nil {
t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err)
return
}
convResult, err := safeConvert(reflect.ValueOf(result), toType)
if err != nil {
t.Errorf("Conversion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err)
return
}
if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) {
t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result)
}
}

12
context/renderer.go Normal file
View File

@ -0,0 +1,12 @@
package context
// Renderer defines an http response renderer
type Renderer interface {
Render(ctx *Context)
}
type rendererFunc func(ctx *Context)
func (f rendererFunc) Render(ctx *Context) {
f(ctx)
}

27
context/response.go Normal file
View File

@ -0,0 +1,27 @@
package context
import (
"strconv"
"net/http"
)
const (
//BadRequest indicates http error 400
BadRequest StatusCode = http.StatusBadRequest
//NotFound indicates http error 404
NotFound StatusCode = http.StatusNotFound
)
// StatusCode sets the http response status code
type StatusCode int
func (s StatusCode) Error() string {
return strconv.Itoa(int(s))
}
// Render sets the http status code
func (s StatusCode) Render(ctx *Context) {
ctx.Output.SetStatus(int(s))
}

View File

@ -28,6 +28,7 @@ import (
"strings" "strings"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
) )
@ -51,8 +52,16 @@ type ControllerComments struct {
Router string Router string
AllowHTTPMethods []string AllowHTTPMethods []string
Params []map[string]string Params []map[string]string
MethodParams []*param.MethodParam
} }
// ControllerCommentsSlice implements the sort interface
type ControllerCommentsSlice []ControllerComments
func (p ControllerCommentsSlice) Len() int { return len(p) }
func (p ControllerCommentsSlice) Less(i, j int) bool { return p[i].Router < p[j].Router }
func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// Controller defines some basic http request handler operations, such as // Controller defines some basic http request handler operations, such as
// http context, template and view, session and xsrf. // http context, template and view, session and xsrf.
type Controller struct { type Controller struct {
@ -69,6 +78,7 @@ type Controller struct {
// template data // template data
TplName string TplName string
ViewPath string
Layout string Layout string
LayoutSections map[string]string // the key is the section name and the value is the template name LayoutSections map[string]string // the key is the section name and the value is the template name
TplPrefix string TplPrefix string
@ -185,7 +195,11 @@ func (c *Controller) Render() error {
if err != nil { if err != nil {
return err return err
} }
c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" {
c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
}
return c.Ctx.Output.Body(rb) return c.Ctx.Output.Body(rb)
} }
@ -209,7 +223,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
continue continue
} }
buf.Reset() buf.Reset()
err = ExecuteTemplate(&buf, sectionTpl, c.Data) err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -218,7 +232,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
} }
buf.Reset() buf.Reset()
ExecuteTemplate(&buf, c.Layout, c.Data) ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data)
} }
return buf.Bytes(), err return buf.Bytes(), err
} }
@ -244,9 +258,16 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) {
} }
} }
} }
BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...) BuildTemplate(c.viewPath(), buildFiles...)
} }
return buf, ExecuteTemplate(&buf, c.TplName, c.Data) return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data)
}
func (c *Controller) viewPath() string {
if c.ViewPath == "" {
return BConfig.WebConfig.ViewsPath
}
return c.ViewPath
} }
// Redirect sends the redirection response to url with status code. // Redirect sends the redirection response to url with status code.
@ -302,7 +323,7 @@ func (c *Controller) ServeJSON(encoding ...bool) {
if BConfig.RunMode == PROD { if BConfig.RunMode == PROD {
hasIndent = false hasIndent = false
} }
if len(encoding) > 0 && encoding[0] == true { if len(encoding) > 0 && encoding[0] {
hasEncoding = true hasEncoding = true
} }
c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding)

View File

@ -20,6 +20,8 @@ import (
"testing" "testing"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"os"
"path/filepath"
) )
func TestGetInt(t *testing.T) { func TestGetInt(t *testing.T) {
@ -121,3 +123,59 @@ func TestGetUint64(t *testing.T) {
t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val)
} }
} }
func TestAdditionalViewPaths(t *testing.T) {
dir1 := "_beeTmp"
dir2 := "_beeTmp2"
defer os.RemoveAll(dir1)
defer os.RemoveAll(dir2)
dir1file := "file1.tpl"
dir2file := "file2.tpl"
genFile := func(dir string, name string, content string) {
os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
t.Fatal(err)
} else {
defer f.Close()
f.WriteString(content)
f.Close()
}
}
genFile(dir1, dir1file, `<div>{{.Content}}</div>`)
genFile(dir2, dir2file, `<html>{{.Content}}</html>`)
AddViewPath(dir1)
AddViewPath(dir2)
ctrl := Controller{
TplName: "file1.tpl",
ViewPath: dir1,
}
ctrl.Data = map[interface{}]interface{}{
"Content": "value2",
}
if result, err := ctrl.RenderString(); err != nil {
t.Fatal(err)
} else {
if result != "<div>value2</div>" {
t.Fatalf("TestAdditionalViewPaths expect %s got %s", "<div>value2</div>", result)
}
}
func() {
ctrl.TplName = "file2.tpl"
defer func() {
if r := recover(); r == nil {
t.Fatal("TestAdditionalViewPaths expected error")
}
}()
ctrl.RenderString()
}()
ctrl.TplName = "file2.tpl"
ctrl.ViewPath = dir2
ctrl.RenderString()
}

View File

@ -252,6 +252,30 @@ func forbidden(rw http.ResponseWriter, r *http.Request) {
) )
} }
// show 422 missing xsrf token
func missingxsrf(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
422,
"<br>The page you have requested is forbidden."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>'_xsrf' argument missing from POST"+
"</ul>",
)
}
// show 417 invalid xsrf token
func invalidxsrf(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
417,
"<br>The page you have requested is forbidden."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>expected XSRF not found"+
"</ul>",
)
}
// show 404 not found error. // show 404 not found error.
func notFound(rw http.ResponseWriter, r *http.Request) { func notFound(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r, responseError(rw, r,

View File

@ -52,7 +52,7 @@ func TestErrorCode_01(t *testing.T) {
if w.Code != code { if w.Code != code {
t.Fail() t.Fail()
} }
if !strings.Contains(string(w.Body.Bytes()), http.StatusText(code)) { if !strings.Contains(w.Body.String(), http.StatusText(code)) {
t.Fail() t.Fail()
} }
} }
@ -82,7 +82,7 @@ func TestErrorCode_03(t *testing.T) {
if w.Code != 200 { if w.Code != 200 {
t.Fail() t.Fail()
} }
if string(w.Body.Bytes()) != parseCodeError { if w.Body.String() != parseCodeError {
t.Fail() t.Fail()
} }
} }

View File

@ -48,7 +48,7 @@ func TestFlashHeader(t *testing.T) {
// match for the expected header // match for the expected header
res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00") res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
// validate the assertion // validate the assertion
if res != true { if !res {
t.Errorf("TestFlashHeader() unable to validate flash message") t.Errorf("TestFlashHeader() unable to validate flash message")
} }
} }

View File

@ -3,14 +3,17 @@ package grace
import ( import (
"errors" "errors"
"net" "net"
"sync"
) )
type graceConn struct { type graceConn struct {
net.Conn net.Conn
server *Server server *Server
m sync.Mutex
closed bool
} }
func (c graceConn) Close() (err error) { func (c *graceConn) Close() (err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
switch x := r.(type) { switch x := r.(type) {
@ -23,6 +26,14 @@ func (c graceConn) Close() (err error) {
} }
} }
}() }()
c.m.Lock()
if c.closed {
c.m.Unlock()
return
}
c.server.wg.Done() c.server.wg.Done()
c.closed = true
c.m.Unlock()
return c.Conn.Close() return c.Conn.Close()
} }

View File

@ -85,23 +85,31 @@ var (
isChild bool isChild bool
socketOrder string socketOrder string
once sync.Once
hookableSignals []os.Signal
) )
func onceInit() { func init() {
regLock = &sync.Mutex{}
flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
regLock = &sync.Mutex{}
runningServers = make(map[string]*Server) runningServers = make(map[string]*Server)
runningServersOrder = []string{} runningServersOrder = []string{}
socketPtrOffsetMap = make(map[string]uint) socketPtrOffsetMap = make(map[string]uint)
hookableSignals = []os.Signal{
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
}
} }
// NewServer returns a new graceServer. // NewServer returns a new graceServer.
func NewServer(addr string, handler http.Handler) (srv *Server) { func NewServer(addr string, handler http.Handler) (srv *Server) {
once.Do(onceInit)
regLock.Lock() regLock.Lock()
defer regLock.Unlock() defer regLock.Unlock()
if !flag.Parsed() { if !flag.Parsed() {
flag.Parse() flag.Parse()
} }

View File

@ -21,7 +21,7 @@ func newGraceListener(l net.Listener, srv *Server) (el *graceListener) {
server: srv, server: srv,
} }
go func() { go func() {
_ = <-el.stop <-el.stop
el.stopped = true el.stopped = true
el.stop <- el.Listener.Close() el.stop <- el.Listener.Close()
}() }()
@ -37,7 +37,7 @@ func (gl *graceListener) Accept() (c net.Conn, err error) {
tc.SetKeepAlive(true) tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute) tc.SetKeepAlivePeriod(3 * time.Minute)
c = graceConn{ c = &graceConn{
Conn: tc, Conn: tc,
server: gl.server, server: gl.server,
} }

View File

@ -162,9 +162,7 @@ func (srv *Server) handleSignals() {
signal.Notify( signal.Notify(
srv.sigChan, srv.sigChan,
syscall.SIGHUP, hookableSignals...,
syscall.SIGINT,
syscall.SIGTERM,
) )
pid := syscall.Getpid() pid := syscall.Getpid()
@ -198,7 +196,6 @@ func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
for _, f := range srv.SignalHooks[ppFlag][sig] { for _, f := range srv.SignalHooks[ppFlag][sig] {
f() f()
} }
return
} }
// shutdown closes the listener so that no new connections are accepted. it also // shutdown closes the listener so that no new connections are accepted. it also
@ -290,3 +287,19 @@ func (srv *Server) fork() (err error) {
return return
} }
// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal.
func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
if ppFlag != PreSignal && ppFlag != PostSignal {
err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal")
return
}
for _, s := range hookableSignals {
if s == sig {
srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
return
}
}
err = fmt.Errorf("Signal '%v' is not supported", sig)
return
}

View File

@ -32,6 +32,8 @@ func registerDefaultErrorHandler() error {
"502": badGateway, "502": badGateway,
"503": serviceUnavailable, "503": serviceUnavailable,
"504": gatewayTimeout, "504": gatewayTimeout,
"417": invalidxsrf,
"422": missingxsrf,
} }
for e, h := range m { for e, h := range m {
if _, ok := ErrorMaps[e]; !ok { if _, ok := ErrorMaps[e]; !ok {
@ -55,9 +57,9 @@ func registerSession() error {
conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig) conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig)
conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly
conf.Domain = BConfig.WebConfig.Session.SessionDomain conf.Domain = BConfig.WebConfig.Session.SessionDomain
conf.EnableSidInHttpHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader
conf.SessionNameInHttpHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader
conf.EnableSidInUrlQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery
} else { } else {
if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil {
return err return err
@ -72,7 +74,8 @@ func registerSession() error {
} }
func registerTemplate() error { func registerTemplate() error {
if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil { defer lockViewPaths()
if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
logs.Warn(err) logs.Warn(err)
} }

View File

@ -32,7 +32,7 @@ The default timeout is `60` seconds, function prototype:
SetTimeout(connectTimeout, readWriteTimeout time.Duration) SetTimeout(connectTimeout, readWriteTimeout time.Duration)
Exmaple: Example:
// GET // GET
httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)

View File

@ -140,6 +140,7 @@ type BeegoHTTPSettings struct {
EnableCookie bool EnableCookie bool
Gzip bool Gzip bool
DumpBody bool DumpBody bool
Retries int // if set to -1 means will retry forever
} }
// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. // BeegoHTTPRequest provides more useful methods for requesting one url than http.Request.
@ -189,6 +190,15 @@ func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
return b return b
} }
// Retries sets Retries times.
// default is 0 means no retried.
// -1 means retried forever.
// others means retried times.
func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest {
b.setting.Retries = times
return b
}
// DumpBody setting whether need to Dump the Body. // DumpBody setting whether need to Dump the Body.
func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
b.setting.DumpBody = isdump b.setting.DumpBody = isdump
@ -325,7 +335,7 @@ func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error)
func (b *BeegoHTTPRequest) buildURL(paramBody string) { func (b *BeegoHTTPRequest) buildURL(paramBody string) {
// build GET url with query string // build GET url with query string
if b.req.Method == "GET" && len(paramBody) > 0 { if b.req.Method == "GET" && len(paramBody) > 0 {
if strings.Index(b.url, "?") != -1 { if strings.Contains(b.url, "?") {
b.url += "&" + paramBody b.url += "&" + paramBody
} else { } else {
b.url = b.url + "?" + paramBody b.url = b.url + "?" + paramBody
@ -334,7 +344,7 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) {
} }
// build POST/PUT/PATCH url and body // build POST/PUT/PATCH url and body
if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH") && b.req.Body == nil { if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil {
// with files // with files
if len(b.files) > 0 { if len(b.files) > 0 {
pr, pw := io.Pipe() pr, pw := io.Pipe()
@ -390,7 +400,7 @@ func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
} }
// DoRequest will do the client.Do // DoRequest will do the client.Do
func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) { func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) {
var paramBody string var paramBody string
if len(b.params) > 0 { if len(b.params) > 0 {
var buf bytes.Buffer var buf bytes.Buffer
@ -467,7 +477,16 @@ func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
} }
b.dump = dump b.dump = dump
} }
return client.Do(b.req) // retries default value is 0, it will run once.
// retries equal to -1, it will run forever until success
// retries is setted, it will retries fixed times.
for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
resp, err = client.Do(b.req)
if err == nil {
break
}
}
return resp, err
} }
// String returns the body string in response. // String returns the body string in response.
@ -501,9 +520,9 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
return nil, err return nil, err
} }
b.body, err = ioutil.ReadAll(reader) b.body, err = ioutil.ReadAll(reader)
} else { return b.body, err
b.body, err = ioutil.ReadAll(resp.Body)
} }
b.body, err = ioutil.ReadAll(resp.Body)
return b.body, err return b.body, err
} }

View File

@ -102,6 +102,14 @@ func TestSimpleDelete(t *testing.T) {
t.Log(str) t.Log(str)
} }
func TestSimpleDeleteParam(t *testing.T) {
str, err := Delete("http://httpbin.org/delete").Param("key", "val").String()
if err != nil {
t.Fatal(err)
}
t.Log(str)
}
func TestWithCookie(t *testing.T) { func TestWithCookie(t *testing.T) {
v := "smallfish" v := "smallfish"
str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String()

186
logs/alils/alils.go Normal file
View File

@ -0,0 +1,186 @@
package alils
import (
"encoding/json"
"strings"
"sync"
"time"
"github.com/astaxie/beego/logs"
"github.com/gogo/protobuf/proto"
)
const (
// CacheSize set the flush size
CacheSize int = 64
// Delimiter define the topic delimiter
Delimiter string = "##"
)
// Config is the Config for Ali Log
type Config struct {
Project string `json:"project"`
Endpoint string `json:"endpoint"`
KeyID string `json:"key_id"`
KeySecret string `json:"key_secret"`
LogStore string `json:"log_store"`
Topics []string `json:"topics"`
Source string `json:"source"`
Level int `json:"level"`
FlushWhen int `json:"flush_when"`
}
// aliLSWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection.
type aliLSWriter struct {
store *LogStore
group []*LogGroup
withMap bool
groupMap map[string]*LogGroup
lock *sync.Mutex
Config
}
// NewAliLS create a new Logger
func NewAliLS() logs.Logger {
alils := new(aliLSWriter)
alils.Level = logs.LevelTrace
return alils
}
// Init parse config and init struct
func (c *aliLSWriter) Init(jsonConfig string) (err error) {
json.Unmarshal([]byte(jsonConfig), c)
if c.FlushWhen > CacheSize {
c.FlushWhen = CacheSize
}
prj := &LogProject{
Name: c.Project,
Endpoint: c.Endpoint,
AccessKeyID: c.KeyID,
AccessKeySecret: c.KeySecret,
}
c.store, err = prj.GetLogStore(c.LogStore)
if err != nil {
return err
}
// Create default Log Group
c.group = append(c.group, &LogGroup{
Topic: proto.String(""),
Source: proto.String(c.Source),
Logs: make([]*Log, 0, c.FlushWhen),
})
// Create other Log Group
c.groupMap = make(map[string]*LogGroup)
for _, topic := range c.Topics {
lg := &LogGroup{
Topic: proto.String(topic),
Source: proto.String(c.Source),
Logs: make([]*Log, 0, c.FlushWhen),
}
c.group = append(c.group, lg)
c.groupMap[topic] = lg
}
if len(c.group) == 1 {
c.withMap = false
} else {
c.withMap = true
}
c.lock = &sync.Mutex{}
return nil
}
// WriteMsg write message in connection.
// if connection is down, try to re-connect.
func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) {
if level > c.Level {
return nil
}
var topic string
var content string
var lg *LogGroup
if c.withMap {
// TopicLogGroup
strs := strings.SplitN(msg, Delimiter, 2)
if len(strs) == 2 {
pos := strings.LastIndex(strs[0], " ")
topic = strs[0][pos+1 : len(strs[0])]
content = strs[0][0:pos] + strs[1]
lg = c.groupMap[topic]
}
// send to empty Topic
if lg == nil {
content = msg
lg = c.group[0]
}
} else {
content = msg
lg = c.group[0]
}
c1 := &LogContent{
Key: proto.String("msg"),
Value: proto.String(content),
}
l := &Log{
Time: proto.Uint32(uint32(when.Unix())),
Contents: []*LogContent{
c1,
},
}
c.lock.Lock()
lg.Logs = append(lg.Logs, l)
c.lock.Unlock()
if len(lg.Logs) >= c.FlushWhen {
c.flush(lg)
}
return nil
}
// Flush implementing method. empty.
func (c *aliLSWriter) Flush() {
// flush all group
for _, lg := range c.group {
c.flush(lg)
}
}
// Destroy destroy connection writer and close tcp listener.
func (c *aliLSWriter) Destroy() {
}
func (c *aliLSWriter) flush(lg *LogGroup) {
c.lock.Lock()
defer c.lock.Unlock()
err := c.store.PutLogs(lg)
if err != nil {
return
}
lg.Logs = make([]*Log, 0, c.FlushWhen)
}
func init() {
logs.Register(logs.AdapterAliLS, NewAliLS)
}

13
logs/alils/config.go Executable file
View File

@ -0,0 +1,13 @@
package alils
const (
version = "0.5.0" // SDK version
signatureMethod = "hmac-sha1" // Signature method
// OffsetNewest stands for the log head offset, i.e. the offset that will be
// assigned to the next message that will be produced to the shard.
OffsetNewest = "end"
// OffsetOldest stands for the oldest offset available on the logstore for a
// shard.
OffsetOldest = "begin"
)

1038
logs/alils/log.pb.go Executable file

File diff suppressed because it is too large Load Diff

42
logs/alils/log_config.go Executable file
View File

@ -0,0 +1,42 @@
package alils
// InputDetail define log detail
type InputDetail struct {
LogType string `json:"logType"`
LogPath string `json:"logPath"`
FilePattern string `json:"filePattern"`
LocalStorage bool `json:"localStorage"`
TimeFormat string `json:"timeFormat"`
LogBeginRegex string `json:"logBeginRegex"`
Regex string `json:"regex"`
Keys []string `json:"key"`
FilterKeys []string `json:"filterKey"`
FilterRegex []string `json:"filterRegex"`
TopicFormat string `json:"topicFormat"`
}
// OutputDetail define the output detail
type OutputDetail struct {
Endpoint string `json:"endpoint"`
LogStoreName string `json:"logstoreName"`
}
// LogConfig define Log Config
type LogConfig struct {
Name string `json:"configName"`
InputType string `json:"inputType"`
InputDetail InputDetail `json:"inputDetail"`
OutputType string `json:"outputType"`
OutputDetail OutputDetail `json:"outputDetail"`
CreateTime uint32
LastModifyTime uint32
project *LogProject
}
// GetAppliedMachineGroup returns applied machine group of this config.
func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) {
groupNames, err = c.project.GetAppliedMachineGroups(c.Name)
return
}

819
logs/alils/log_project.go Executable file
View File

@ -0,0 +1,819 @@
/*
Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS).
For more description about SLS, please read this article:
http://gitlab.alibaba-inc.com/sls/doc.
*/
package alils
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
)
// Error message in SLS HTTP response.
type errorMessage struct {
Code string `json:"errorCode"`
Message string `json:"errorMessage"`
}
// LogProject Define the Ali Project detail
type LogProject struct {
Name string // Project name
Endpoint string // IP or hostname of SLS endpoint
AccessKeyID string
AccessKeySecret string
}
// NewLogProject creates a new SLS project.
func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) {
p = &LogProject{
Name: name,
Endpoint: endpoint,
AccessKeyID: AccessKeyID,
AccessKeySecret: accessKeySecret,
}
return p, nil
}
// ListLogStore returns all logstore names of project p.
func (p *LogProject) ListLogStore() (storeNames []string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/logstores")
r, err := request(p, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to list logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
Count int
LogStores []string
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
storeNames = body.LogStores
return
}
// GetLogStore returns logstore according by logstore name.
func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "GET", "/logstores/"+name, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
s = &LogStore{}
err = json.Unmarshal(buf, s)
if err != nil {
return
}
s.project = p
return
}
// CreateLogStore creates a new logstore in SLS,
// where name is logstore name,
// and ttl is time-to-live(in day) of logs,
// and shardCnt is the number of shards.
func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) {
type Body struct {
Name string `json:"logstoreName"`
TTL int `json:"ttl"`
ShardCount int `json:"shardCount"`
}
store := &Body{
Name: name,
TTL: ttl,
ShardCount: shardCnt,
}
body, err := json.Marshal(store)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "POST", "/logstores", h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to create logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// DeleteLogStore deletes a logstore according by logstore name.
func (p *LogProject) DeleteLogStore(name string) (err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "DELETE", "/logstores/"+name, h, nil)
if err != nil {
return
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to delete logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// UpdateLogStore updates a logstore according by logstore name,
// obviously we can't modify the logstore name itself.
func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) {
type Body struct {
Name string `json:"logstoreName"`
TTL int `json:"ttl"`
ShardCount int `json:"shardCount"`
}
store := &Body{
Name: name,
TTL: ttl,
ShardCount: shardCnt,
}
body, err := json.Marshal(store)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "PUT", "/logstores", h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to update logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// ListMachineGroup returns machine group name list and the total number of machine groups.
// The offset starts from 0 and the size is the max number of machine groups could be returned.
func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
if size <= 0 {
size = 500
}
uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size)
r, err := request(p, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to list machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
MachineGroups []string
Count int
Total int
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
m = body.MachineGroups
total = body.Total
return
}
// GetMachineGroup retruns machine group according by machine group name.
func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "GET", "/machinegroups/"+name, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get machine group:%v", name)
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
m = &MachineGroup{}
err = json.Unmarshal(buf, m)
if err != nil {
return
}
m.project = p
return
}
// CreateMachineGroup creates a new machine group in SLS.
func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) {
body, err := json.Marshal(m)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "POST", "/machinegroups", h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to create machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// UpdateMachineGroup updates a machine group.
func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) {
body, err := json.Marshal(m)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to update machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// DeleteMachineGroup deletes machine group according machine group name.
func (p *LogProject) DeleteMachineGroup(name string) (err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil)
if err != nil {
return
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to delete machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// ListConfig returns config names list and the total number of configs.
// The offset starts from 0 and the size is the max number of configs could be returned.
func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
if size <= 0 {
size = 100
}
uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size)
r, err := request(p, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to delete machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
Total int
Configs []string
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
cfgNames = body.Configs
total = body.Total
return
}
// GetConfig returns config according by config name.
func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "GET", "/configs/"+name, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to delete config")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
c = &LogConfig{}
err = json.Unmarshal(buf, c)
if err != nil {
return
}
c.project = p
return
}
// UpdateConfig updates a config.
func (p *LogProject) UpdateConfig(c *LogConfig) (err error) {
body, err := json.Marshal(c)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "PUT", "/configs/"+c.Name, h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to update config")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// CreateConfig creates a new config in SLS.
func (p *LogProject) CreateConfig(c *LogConfig) (err error) {
body, err := json.Marshal(c)
if err != nil {
return
}
h := map[string]string{
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/json",
"Accept-Encoding": "deflate", // TODO: support lz4
}
r, err := request(p, "POST", "/configs", h, body)
if err != nil {
return
}
body, err = ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to update config")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// DeleteConfig deletes a config according by config name.
func (p *LogProject) DeleteConfig(name string) (err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
r, err := request(p, "DELETE", "/configs/"+name, h, nil)
if err != nil {
return
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(body, errMsg)
if err != nil {
err = fmt.Errorf("failed to delete config")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// GetAppliedMachineGroups returns applied machine group names list according config name.
func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/configs/%v/machinegroups", confName)
r, err := request(p, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get applied machine groups")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
Count int
Machinegroups []string
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
groupNames = body.Machinegroups
return
}
// GetAppliedConfigs returns applied config names list according machine group name groupName.
func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/machinegroups/%v/configs", groupName)
r, err := request(p, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to applied configs")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Cfg struct {
Count int `json:"count"`
Configs []string `json:"configs"`
}
body := &Cfg{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
confNames = body.Configs
return
}
// ApplyConfigToMachineGroup applies config to machine group.
func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName)
r, err := request(p, "PUT", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to apply config to machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// RemoveConfigFromMachineGroup removes config from machine group.
func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName)
r, err := request(p, "DELETE", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to remove config from machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Printf("%s\n", dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}

271
logs/alils/log_store.go Executable file
View File

@ -0,0 +1,271 @@
package alils
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
"strconv"
lz4 "github.com/cloudflare/golz4"
"github.com/gogo/protobuf/proto"
)
// LogStore Store the logs
type LogStore struct {
Name string `json:"logstoreName"`
TTL int
ShardCount int
CreateTime uint32
LastModifyTime uint32
project *LogProject
}
// Shard define the Log Shard
type Shard struct {
ShardID int `json:"shardID"`
}
// ListShards returns shard id list of this logstore.
func (s *LogStore) ListShards() (shardIDs []int, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/logstores/%v/shards", s.Name)
r, err := request(s.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to list logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
var shards []*Shard
err = json.Unmarshal(buf, &shards)
if err != nil {
return
}
for _, v := range shards {
shardIDs = append(shardIDs, v.ShardID)
}
return
}
// PutLogs put logs into logstore.
// The callers should transform user logs into LogGroup.
func (s *LogStore) PutLogs(lg *LogGroup) (err error) {
body, err := proto.Marshal(lg)
if err != nil {
return
}
// Compresse body with lz4
out := make([]byte, lz4.CompressBound(body))
n, err := lz4.Compress(body, out)
if err != nil {
return
}
h := map[string]string{
"x-sls-compresstype": "lz4",
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/x-protobuf",
}
uri := fmt.Sprintf("/logstores/%v", s.Name)
r, err := request(s.project, "POST", uri, h, out[:n])
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to put logs")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// GetCursor gets log cursor of one shard specified by shardID.
// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end".
// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore
func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v",
s.Name, shardID, from)
r, err := request(s.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get cursor")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
Cursor string
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
cursor = body.Cursor
return
}
// GetLogsBytes gets logs binary data from shard specified by shardID according cursor.
// The logGroupMaxCount is the max number of logGroup could be returned.
// The nextCursor is the next curosr can be used to read logs at next time.
func (s *LogStore) GetLogsBytes(shardID int, cursor string,
logGroupMaxCount int) (out []byte, nextCursor string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
"Accept": "application/x-protobuf",
"Accept-Encoding": "lz4",
}
uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v",
s.Name, shardID, cursor, logGroupMaxCount)
r, err := request(s.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get cursor")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
v, ok := r.Header["X-Sls-Compresstype"]
if !ok || len(v) == 0 {
err = fmt.Errorf("can't find 'x-sls-compresstype' header")
return
}
if v[0] != "lz4" {
err = fmt.Errorf("unexpected compress type:%v", v[0])
return
}
v, ok = r.Header["X-Sls-Cursor"]
if !ok || len(v) == 0 {
err = fmt.Errorf("can't find 'x-sls-cursor' header")
return
}
nextCursor = v[0]
v, ok = r.Header["X-Sls-Bodyrawsize"]
if !ok || len(v) == 0 {
err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header")
return
}
bodyRawSize, err := strconv.Atoi(v[0])
if err != nil {
return
}
out = make([]byte, bodyRawSize)
err = lz4.Uncompress(buf, out)
if err != nil {
return
}
return
}
// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API
func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) {
gl = &LogGroupList{}
err = proto.Unmarshal(data, gl)
if err != nil {
return
}
return
}
// GetLogs gets logs from shard specified by shardID according cursor.
// The logGroupMaxCount is the max number of logGroup could be returned.
// The nextCursor is the next curosr can be used to read logs at next time.
func (s *LogStore) GetLogs(shardID int, cursor string,
logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) {
out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount)
if err != nil {
return
}
gl, err = LogsBytesDecode(out)
if err != nil {
return
}
return
}

91
logs/alils/machine_group.go Executable file
View File

@ -0,0 +1,91 @@
package alils
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
)
// MachineGroupAttribute define the Attribute
type MachineGroupAttribute struct {
ExternalName string `json:"externalName"`
TopicName string `json:"groupTopic"`
}
// MachineGroup define the machine Group
type MachineGroup struct {
Name string `json:"groupName"`
Type string `json:"groupType"`
MachineIDType string `json:"machineIdentifyType"`
MachineIDList []string `json:"machineList"`
Attribute MachineGroupAttribute `json:"groupAttribute"`
CreateTime uint32
LastModifyTime uint32
project *LogProject
}
// Machine define the Machine
type Machine struct {
IP string
UniqueID string `json:"machine-uniqueid"`
UserdefinedID string `json:"userdefined-id"`
}
// MachineList define the Machine List
type MachineList struct {
Total int
Machines []*Machine
}
// ListMachines returns machine list of this machine group.
func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name)
r, err := request(m.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to remove config from machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
body := &MachineList{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
ms = body.Machines
total = body.Total
return
}
// GetAppliedConfigs returns applied configs of this machine group.
func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) {
confNames, err = m.project.GetAppliedConfigs(m.Name)
return
}

62
logs/alils/request.go Executable file
View File

@ -0,0 +1,62 @@
package alils
import (
"bytes"
"crypto/md5"
"fmt"
"net/http"
)
// request sends a request to SLS.
func request(project *LogProject, method, uri string, headers map[string]string,
body []byte) (resp *http.Response, err error) {
// The caller should provide 'x-sls-bodyrawsize' header
if _, ok := headers["x-sls-bodyrawsize"]; !ok {
err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header")
return
}
// SLS public request headers
headers["Host"] = project.Name + "." + project.Endpoint
headers["Date"] = nowRFC1123()
headers["x-sls-apiversion"] = version
headers["x-sls-signaturemethod"] = signatureMethod
if body != nil {
bodyMD5 := fmt.Sprintf("%X", md5.Sum(body))
headers["Content-MD5"] = bodyMD5
if _, ok := headers["Content-Type"]; !ok {
err = fmt.Errorf("Can't find 'Content-Type' header")
return
}
}
// Calc Authorization
// Authorization = "SLS <AccessKeyID>:<Signature>"
digest, err := signature(project, method, uri, headers)
if err != nil {
return
}
auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest)
headers["Authorization"] = auth
// Initialize http request
reader := bytes.NewReader(body)
urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri)
req, err := http.NewRequest(method, urlStr, reader)
if err != nil {
return
}
for k, v := range headers {
req.Header.Add(k, v)
}
// Get ready to do request
resp, err = http.DefaultClient.Do(req)
if err != nil {
return
}
return
}

111
logs/alils/signature.go Executable file
View File

@ -0,0 +1,111 @@
package alils
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"net/url"
"sort"
"strings"
"time"
)
// GMT location
var gmtLoc = time.FixedZone("GMT", 0)
// NowRFC1123 returns now time in RFC1123 format with GMT timezone,
// eg. "Mon, 02 Jan 2006 15:04:05 GMT".
func nowRFC1123() string {
return time.Now().In(gmtLoc).Format(time.RFC1123)
}
// signature calculates a request's signature digest.
func signature(project *LogProject, method, uri string,
headers map[string]string) (digest string, err error) {
var contentMD5, contentType, date, canoHeaders, canoResource string
var slsHeaderKeys sort.StringSlice
// SignString = VERB + "\n"
// + CONTENT-MD5 + "\n"
// + CONTENT-TYPE + "\n"
// + DATE + "\n"
// + CanonicalizedSLSHeaders + "\n"
// + CanonicalizedResource
if val, ok := headers["Content-MD5"]; ok {
contentMD5 = val
}
if val, ok := headers["Content-Type"]; ok {
contentType = val
}
date, ok := headers["Date"]
if !ok {
err = fmt.Errorf("Can't find 'Date' header")
return
}
// Calc CanonicalizedSLSHeaders
slsHeaders := make(map[string]string, len(headers))
for k, v := range headers {
l := strings.TrimSpace(strings.ToLower(k))
if strings.HasPrefix(l, "x-sls-") {
slsHeaders[l] = strings.TrimSpace(v)
slsHeaderKeys = append(slsHeaderKeys, l)
}
}
sort.Sort(slsHeaderKeys)
for i, k := range slsHeaderKeys {
canoHeaders += k + ":" + slsHeaders[k]
if i+1 < len(slsHeaderKeys) {
canoHeaders += "\n"
}
}
// Calc CanonicalizedResource
u, err := url.Parse(uri)
if err != nil {
return
}
canoResource += url.QueryEscape(u.Path)
if u.RawQuery != "" {
var keys sort.StringSlice
vals := u.Query()
for k := range vals {
keys = append(keys, k)
}
sort.Sort(keys)
canoResource += "?"
for i, k := range keys {
if i > 0 {
canoResource += "&"
}
for _, v := range vals[k] {
canoResource += k + "=" + v
}
}
}
signStr := method + "\n" +
contentMD5 + "\n" +
contentType + "\n" +
date + "\n" +
canoHeaders + "\n" +
canoResource
// Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString)AccessKeySecret))
mac := hmac.New(sha1.New, []byte(project.AccessKeySecret))
_, err = mac.Write([]byte(signStr))
if err != nil {
return
}
digest = base64.StdEncoding.EncodeToString(mac.Sum(nil))
return
}

View File

@ -361,7 +361,7 @@ func isParameterChar(b byte) bool {
} }
func (cw *ansiColorWriter) Write(p []byte) (int, error) { func (cw *ansiColorWriter) Write(p []byte) (int, error) {
r, nw, first, last := 0, 0, 0, 0 var r, nw, first, last int
if cw.mode != DiscardNonColorEscSeq { if cw.mode != DiscardNonColorEscSeq {
cw.state = outsideCsiCode cw.state = outsideCsiCode
cw.resetBuffer() cw.resetBuffer()

View File

@ -41,7 +41,7 @@ var colors = []brush{
newBrush("1;33"), // Warning yellow newBrush("1;33"), // Warning yellow
newBrush("1;32"), // Notice green newBrush("1;32"), // Notice green
newBrush("1;34"), // Informational blue newBrush("1;34"), // Informational blue
newBrush("1;34"), // Debug blue newBrush("1;44"), // Debug Background blue
} }
// consoleWriter implements LoggerInterface and writes messages to terminal. // consoleWriter implements LoggerInterface and writes messages to terminal.

View File

@ -56,17 +56,20 @@ type fileLogWriter struct {
Perm string `json:"perm"` Perm string `json:"perm"`
RotatePerm string `json:"rotateperm"`
fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix
} }
// newFileWriter create a FileLogWriter returning as LoggerInterface. // newFileWriter create a FileLogWriter returning as LoggerInterface.
func newFileWriter() Logger { func newFileWriter() Logger {
w := &fileLogWriter{ w := &fileLogWriter{
Daily: true, Daily: true,
MaxDays: 7, MaxDays: 7,
Rotate: true, Rotate: true,
Level: LevelTrace, RotatePerm: "0440",
Perm: "0660", Level: LevelTrace,
Perm: "0660",
} }
return w return w
} }
@ -170,7 +173,7 @@ func (w *fileLogWriter) initFd() error {
fd := w.fileWriter fd := w.fileWriter
fInfo, err := fd.Stat() fInfo, err := fd.Stat()
if err != nil { if err != nil {
return fmt.Errorf("get stat err: %s\n", err) return fmt.Errorf("get stat err: %s", err)
} }
w.maxSizeCurSize = int(fInfo.Size()) w.maxSizeCurSize = int(fInfo.Size())
w.dailyOpenTime = time.Now() w.dailyOpenTime = time.Now()
@ -193,16 +196,14 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) {
y, m, d := openTime.Add(24 * time.Hour).Date() y, m, d := openTime.Add(24 * time.Hour).Date()
nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location()) nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location())
tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100))
select { <-tm.C
case <-tm.C: w.Lock()
w.Lock() if w.needRotate(0, time.Now().Day()) {
if w.needRotate(0, time.Now().Day()) { if err := w.doRotate(time.Now()); err != nil {
if err := w.doRotate(time.Now()); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
} }
w.Unlock()
} }
w.Unlock()
} }
func (w *fileLogWriter) lines() (int, error) { func (w *fileLogWriter) lines() (int, error) {
@ -239,8 +240,12 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
// Find the next available number // Find the next available number
num := 1 num := 1
fName := "" fName := ""
rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64)
if err != nil {
return err
}
_, err := os.Lstat(w.Filename) _, err = os.Lstat(w.Filename)
if err != nil { if err != nil {
//even if the file is not exist or other ,we should RESTART the logger //even if the file is not exist or other ,we should RESTART the logger
goto RESTART_LOGGER goto RESTART_LOGGER
@ -261,7 +266,7 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
} }
// return error if the last file checked still existed // return error if the last file checked still existed
if err == nil { if err == nil {
return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename) return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename)
} }
// close fileWriter before rename // close fileWriter before rename
@ -270,20 +275,24 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
// Rename the file to its new found name // Rename the file to its new found name
// even if occurs error,we MUST guarantee to restart new logger // even if occurs error,we MUST guarantee to restart new logger
err = os.Rename(w.Filename, fName) err = os.Rename(w.Filename, fName)
// re-start logger if err != nil {
goto RESTART_LOGGER
}
err = os.Chmod(fName, os.FileMode(rotatePerm))
RESTART_LOGGER: RESTART_LOGGER:
startLoggerErr := w.startLogger() startLoggerErr := w.startLogger()
go w.deleteOldLog() go w.deleteOldLog()
if startLoggerErr != nil { if startLoggerErr != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr) return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr)
} }
if err != nil { if err != nil {
return fmt.Errorf("Rotate: %s\n", err) return fmt.Errorf("Rotate: %s", err)
} }
return nil return nil
} }
func (w *fileLogWriter) deleteOldLog() { func (w *fileLogWriter) deleteOldLog() {

View File

@ -162,14 +162,35 @@ func TestFileRotate_05(t *testing.T) {
testFileDailyRotate(t, fn1, fn2) testFileDailyRotate(t, fn1, fn2)
os.Remove(fn) os.Remove(fn)
} }
func TestFileRotate_06(t *testing.T) { //test file mode
log := NewLogger(10000)
log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
log.Debug("debug")
log.Info("info")
log.Notice("notice")
log.Warning("warning")
log.Error("error")
log.Alert("alert")
log.Critical("critical")
log.Emergency("emergency")
rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
s, _ := os.Lstat(rotateName)
if s.Mode() != 0440 {
os.Remove(rotateName)
os.Remove("test3.log")
t.Fatal("rotate file mode error")
}
os.Remove(rotateName)
os.Remove("test3.log")
}
func testFileRotate(t *testing.T, fn1, fn2 string) { func testFileRotate(t *testing.T, fn1, fn2 string) {
fw := &fileLogWriter{ fw := &fileLogWriter{
Daily: true, Daily: true,
MaxDays: 7, MaxDays: 7,
Rotate: true, Rotate: true,
Level: LevelTrace, Level: LevelTrace,
Perm: "0660", Perm: "0660",
RotatePerm: "0440",
} }
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
@ -188,11 +209,12 @@ func testFileRotate(t *testing.T, fn1, fn2 string) {
func testFileDailyRotate(t *testing.T, fn1, fn2 string) { func testFileDailyRotate(t *testing.T, fn1, fn2 string) {
fw := &fileLogWriter{ fw := &fileLogWriter{
Daily: true, Daily: true,
MaxDays: 7, MaxDays: 7,
Rotate: true, Rotate: true,
Level: LevelTrace, Level: LevelTrace,
Perm: "0660", Perm: "0660",
RotatePerm: "0440",
} }
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)

View File

@ -25,11 +25,7 @@ func newJLWriter() Logger {
// Init JLWriter with json config string // Init JLWriter with json config string
func (s *JLWriter) Init(jsonconfig string) error { func (s *JLWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s) return json.Unmarshal([]byte(jsonconfig), s)
if err != nil {
return err
}
return nil
} }
// WriteMsg write message in smtp writer. // WriteMsg write message in smtp writer.
@ -65,12 +61,10 @@ func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error {
// Flush implementing method. empty. // Flush implementing method. empty.
func (s *JLWriter) Flush() { func (s *JLWriter) Flush() {
return
} }
// Destroy implementing method. empty. // Destroy implementing method. empty.
func (s *JLWriter) Destroy() { func (s *JLWriter) Destroy() {
return
} }
func init() { func init() {

View File

@ -71,6 +71,7 @@ const (
AdapterEs = "es" AdapterEs = "es"
AdapterJianLiao = "jianliao" AdapterJianLiao = "jianliao"
AdapterSlack = "slack" AdapterSlack = "slack"
AdapterAliLS = "alils"
) )
// Legacy log level constants to ensure backwards compatibility. // Legacy log level constants to ensure backwards compatibility.
@ -274,7 +275,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error
line = 0 line = 0
} }
_, filename := path.Split(file) _, filename := path.Split(file)
msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg
} }
//set level info in front of filename info //set level info in front of filename info
@ -491,9 +492,9 @@ func (bl *BeeLogger) flush() {
} }
// beeLogger references the used application logger. // beeLogger references the used application logger.
var beeLogger *BeeLogger = NewLogger() var beeLogger = NewLogger()
// GetLogger returns the default BeeLogger // GetBeeLogger returns the default BeeLogger
func GetBeeLogger() *BeeLogger { func GetBeeLogger() *BeeLogger {
return beeLogger return beeLogger
} }
@ -533,6 +534,7 @@ func Reset() {
beeLogger.Reset() beeLogger.Reset()
} }
// Async set the beelogger with Async mode and hold msglen messages
func Async(msgLen ...int64) *BeeLogger { func Async(msgLen ...int64) *BeeLogger {
return beeLogger.Async(msgLen...) return beeLogger.Async(msgLen...)
} }
@ -560,11 +562,7 @@ func SetLogFuncCallDepth(d int) {
// SetLogger sets a new logger. // SetLogger sets a new logger.
func SetLogger(adapter string, config ...string) error { func SetLogger(adapter string, config ...string) error {
err := beeLogger.SetLogger(adapter, config...) return beeLogger.SetLogger(adapter, config...)
if err != nil {
return err
}
return nil
} }
// Emergency logs a message at emergency level. // Emergency logs a message at emergency level.

View File

@ -139,6 +139,11 @@ var (
reset = string([]byte{27, 91, 48, 109}) reset = string([]byte{27, 91, 48, 109})
) )
// ColorByStatus return color by http code
// 2xx return Green
// 3xx return White
// 4xx return Yellow
// 5xx return Red
func ColorByStatus(cond bool, code int) string { func ColorByStatus(cond bool, code int) string {
switch { switch {
case code >= 200 && code < 300: case code >= 200 && code < 300:
@ -152,6 +157,14 @@ func ColorByStatus(cond bool, code int) string {
} }
} }
// ColorByMethod return color by http code
// GET return Blue
// POST return Cyan
// PUT return Yellow
// DELETE return Red
// PATCH return Green
// HEAD return Magenta
// OPTIONS return WHITE
func ColorByMethod(cond bool, method string) string { func ColorByMethod(cond bool, method string) string {
switch method { switch method {
case "GET": case "GET":
@ -173,10 +186,10 @@ func ColorByMethod(cond bool, method string) string {
} }
} }
// Guard Mutex to guarantee atomicity of W32Debug(string) function // Guard Mutex to guarantee atomic of W32Debug(string) function
var mu sync.Mutex var mu sync.Mutex
// Helper method to output colored logs in Windows terminals // W32Debug Helper method to output colored logs in Windows terminals
func W32Debug(msg string) { func W32Debug(msg string) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()

View File

@ -21,11 +21,7 @@ func newSLACKWriter() Logger {
// Init SLACKWriter with json config string // Init SLACKWriter with json config string
func (s *SLACKWriter) Init(jsonconfig string) error { func (s *SLACKWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s) return json.Unmarshal([]byte(jsonconfig), s)
if err != nil {
return err
}
return nil
} }
// WriteMsg write message in smtp writer. // WriteMsg write message in smtp writer.
@ -53,12 +49,10 @@ func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error {
// Flush implementing method. empty. // Flush implementing method. empty.
func (s *SLACKWriter) Flush() { func (s *SLACKWriter) Flush() {
return
} }
// Destroy implementing method. empty. // Destroy implementing method. empty.
func (s *SLACKWriter) Destroy() { func (s *SLACKWriter) Destroy() {
return
} }
func init() { func init() {

View File

@ -52,11 +52,7 @@ func newSMTPWriter() Logger {
// "level":LevelError // "level":LevelError
// } // }
func (s *SMTPWriter) Init(jsonconfig string) error { func (s *SMTPWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s) return json.Unmarshal([]byte(jsonconfig), s)
if err != nil {
return err
}
return nil
} }
func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth {
@ -106,7 +102,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd
if err != nil { if err != nil {
return err return err
} }
_, err = w.Write([]byte(msgContent)) _, err = w.Write(msgContent)
if err != nil { if err != nil {
return err return err
} }
@ -116,12 +112,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd
return err return err
} }
err = client.Quit() return client.Quit()
if err != nil {
return err
}
return nil
} }
// WriteMsg write message in smtp writer. // WriteMsg write message in smtp writer.
@ -147,12 +138,10 @@ func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error {
// Flush implementing method. empty. // Flush implementing method. empty.
func (s *SMTPWriter) Flush() { func (s *SMTPWriter) Flush() {
return
} }
// Destroy implementing method. empty. // Destroy implementing method. empty.
func (s *SMTPWriter) Destroy() { func (s *SMTPWriter) Destroy() {
return
} }
func init() { func init() {

View File

@ -14,40 +14,382 @@
package migration package migration
// Table store the tablename and Column import (
type Table struct { "fmt"
TableName string
Columns []*Column "github.com/astaxie/beego"
)
// Index struct defines the structure of Index Columns
type Index struct {
Name string
} }
// Create return the create sql // Unique struct defines a single unique key combination
func (t *Table) Create() string { type Unique struct {
return "" Definition string
Columns []*Column
} }
// Drop return the drop sql //Column struct defines a single column of a table
func (t *Table) Drop() string {
return ""
}
// Column define the columns name type and Default
type Column struct { type Column struct {
Name string Name string
Type string Inc string
Default interface{} Null string
Default string
Unsign string
DataType string
remove bool
Modify bool
} }
// Create return create sql with the provided tbname and columns // Foreign struct defines a single foreign relationship
func Create(tbname string, columns ...Column) string { type Foreign struct {
return "" ForeignTable string
ForeignColumn string
OnDelete string
OnUpdate string
Column
} }
// Drop return the drop sql with the provided tbname and columns // RenameColumn struct allows renaming of columns
func Drop(tbname string, columns ...Column) string { type RenameColumn struct {
return "" OldName string
OldNull string
OldDefault string
OldUnsign string
OldDataType string
NewName string
Column
} }
// TableDDL is still in think // CreateTable creates the table on system
func TableDDL(tbname string, columns ...Column) string { func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) {
return "" m.TableName = tablename
m.Engine = engine
m.Charset = charset
m.ModifyType = "create"
}
// AlterTable set the ModifyType to alter
func (m *Migration) AlterTable(tablename string) {
m.TableName = tablename
m.ModifyType = "alter"
}
// NewCol creates a new standard column and attaches it to m struct
func (m *Migration) NewCol(name string) *Column {
col := &Column{Name: name}
m.AddColumns(col)
return col
}
//PriCol creates a new primary column and attaches it to m struct
func (m *Migration) PriCol(name string) *Column {
col := &Column{Name: name}
m.AddColumns(col)
m.AddPrimary(col)
return col
}
//UniCol creates / appends columns to specified unique key and attaches it to m struct
func (m *Migration) UniCol(uni, name string) *Column {
col := &Column{Name: name}
m.AddColumns(col)
uniqueOriginal := &Unique{}
for _, unique := range m.Uniques {
if unique.Definition == uni {
unique.AddColumnsToUnique(col)
uniqueOriginal = unique
}
}
if uniqueOriginal.Definition == "" {
unique := &Unique{Definition: uni}
unique.AddColumnsToUnique(col)
m.AddUnique(unique)
}
return col
}
//ForeignCol creates a new foreign column and returns the instance of column
func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) {
foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable}
foreign.Name = colname
m.AddForeign(foreign)
return foreign
}
//SetOnDelete sets the on delete of foreign
func (foreign *Foreign) SetOnDelete(del string) *Foreign {
foreign.OnDelete = "ON DELETE" + del
return foreign
}
//SetOnUpdate sets the on update of foreign
func (foreign *Foreign) SetOnUpdate(update string) *Foreign {
foreign.OnUpdate = "ON UPDATE" + update
return foreign
}
//Remove marks the columns to be removed.
//it allows reverse m to create the column.
func (c *Column) Remove() {
c.remove = true
}
//SetAuto enables auto_increment of column (can be used once)
func (c *Column) SetAuto(inc bool) *Column {
if inc {
c.Inc = "auto_increment"
}
return c
}
//SetNullable sets the column to be null
func (c *Column) SetNullable(null bool) *Column {
if null {
c.Null = ""
} else {
c.Null = "NOT NULL"
}
return c
}
//SetDefault sets the default value, prepend with "DEFAULT "
func (c *Column) SetDefault(def string) *Column {
c.Default = "DEFAULT " + def
return c
}
//SetUnsigned sets the column to be unsigned int
func (c *Column) SetUnsigned(unsign bool) *Column {
if unsign {
c.Unsign = "UNSIGNED"
}
return c
}
//SetDataType sets the dataType of the column
func (c *Column) SetDataType(dataType string) *Column {
c.DataType = dataType
return c
}
//SetOldNullable allows reverting to previous nullable on reverse ms
func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn {
if null {
c.OldNull = ""
} else {
c.OldNull = "NOT NULL"
}
return c
}
//SetOldDefault allows reverting to previous default on reverse ms
func (c *RenameColumn) SetOldDefault(def string) *RenameColumn {
c.OldDefault = def
return c
}
//SetOldUnsigned allows reverting to previous unsgined on reverse ms
func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn {
if unsign {
c.OldUnsign = "UNSIGNED"
}
return c
}
//SetOldDataType allows reverting to previous datatype on reverse ms
func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn {
c.OldDataType = dataType
return c
}
//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m)
func (c *Column) SetPrimary(m *Migration) *Column {
m.Primary = append(m.Primary, c)
return c
}
//AddColumnsToUnique adds the columns to Unique Struct
func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique {
unique.Columns = append(unique.Columns, columns...)
return unique
}
//AddColumns adds columns to m struct
func (m *Migration) AddColumns(columns ...*Column) *Migration {
m.Columns = append(m.Columns, columns...)
return m
}
//AddPrimary adds the column to primary in m struct
func (m *Migration) AddPrimary(primary *Column) *Migration {
m.Primary = append(m.Primary, primary)
return m
}
//AddUnique adds the column to unique in m struct
func (m *Migration) AddUnique(unique *Unique) *Migration {
m.Uniques = append(m.Uniques, unique)
return m
}
//AddForeign adds the column to foreign in m struct
func (m *Migration) AddForeign(foreign *Foreign) *Migration {
m.Foreigns = append(m.Foreigns, foreign)
return m
}
//AddIndex adds the column to index in m struct
func (m *Migration) AddIndex(index *Index) *Migration {
m.Indexes = append(m.Indexes, index)
return m
}
//RenameColumn allows renaming of columns
func (m *Migration) RenameColumn(from, to string) *RenameColumn {
rename := &RenameColumn{OldName: from, NewName: to}
m.Renames = append(m.Renames, rename)
return rename
}
//GetSQL returns the generated sql depending on ModifyType
func (m *Migration) GetSQL() (sql string) {
sql = ""
switch m.ModifyType {
case "create":
{
sql += fmt.Sprintf("CREATE TABLE `%s` (", m.TableName)
for index, column := range m.Columns {
sql += fmt.Sprintf("\n `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
if len(m.Columns) > index+1 {
sql += ","
}
}
if len(m.Primary) > 0 {
sql += fmt.Sprintf(",\n PRIMARY KEY( ")
}
for index, column := range m.Primary {
sql += fmt.Sprintf(" `%s`", column.Name)
if len(m.Primary) > index+1 {
sql += ","
}
}
if len(m.Primary) > 0 {
sql += fmt.Sprintf(")")
}
for _, unique := range m.Uniques {
sql += fmt.Sprintf(",\n UNIQUE KEY `%s`( ", unique.Definition)
for index, column := range unique.Columns {
sql += fmt.Sprintf(" `%s`", column.Name)
if len(unique.Columns) > index+1 {
sql += ","
}
}
sql += fmt.Sprintf(")")
}
for _, foreign := range m.Foreigns {
sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
sql += fmt.Sprintf(",\n KEY `%s_%s_foreign`(`%s`),", m.TableName, foreign.Column.Name, foreign.Column.Name)
sql += fmt.Sprintf("\n CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate)
}
sql += fmt.Sprintf(")ENGINE=%s DEFAULT CHARSET=%s;", m.Engine, m.Charset)
break
}
case "alter":
{
sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName)
for index, column := range m.Columns {
if !column.remove {
beego.BeeLogger.Info("col")
sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
} else {
sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
}
if len(m.Columns) > index {
sql += ","
}
}
for index, column := range m.Renames {
sql += fmt.Sprintf("CHANGE COLUMN `%s` `%s` %s %s %s %s %s", column.OldName, column.NewName, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
if len(m.Renames) > index+1 {
sql += ","
}
}
for index, foreign := range m.Foreigns {
sql += fmt.Sprintf("ADD `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
sql += fmt.Sprintf(",\n ADD KEY `%s_%s_foreign`(`%s`)", m.TableName, foreign.Column.Name, foreign.Column.Name)
sql += fmt.Sprintf(",\n ADD CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate)
if len(m.Foreigns) > index+1 {
sql += ","
}
}
sql += ";"
break
}
case "reverse":
{
sql += fmt.Sprintf("ALTER TABLE `%s`", m.TableName)
for index, column := range m.Columns {
if column.remove {
sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
} else {
sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
}
if len(m.Columns) > index {
sql += ","
}
}
if len(m.Primary) > 0 {
sql += fmt.Sprintf("\n DROP PRIMARY KEY,")
}
for index, unique := range m.Uniques {
sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition)
if len(m.Uniques) > index {
sql += ","
}
}
for index, column := range m.Renames {
sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault)
if len(m.Renames) > index {
sql += ","
}
}
for _, foreign := range m.Foreigns {
sql += fmt.Sprintf("\n DROP KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name)
sql += fmt.Sprintf(",\n DROP FOREIGN KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name)
sql += fmt.Sprintf(",\n DROP COLUMN `%s`", foreign.Name)
}
sql += ";"
}
case "delete":
{
sql += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;", m.TableName)
}
}
return
} }

32
migration/doc.go Normal file
View File

@ -0,0 +1,32 @@
// Package migration enables you to generate migrations back and forth. It generates both migrations.
//
// //Creates a table
// m.CreateTable("tablename","InnoDB","utf8");
//
// //Alter a table
// m.AlterTable("tablename")
//
// Standard Column Methods
// * SetDataType
// * SetNullable
// * SetDefault
// * SetUnsigned (use only on integer types unless produces error)
//
// //Sets a primary column, multiple calls allowed, standard column methods available
// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true)
//
// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index
// m.UniCol("index","column")
//
// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove
// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false)
// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false)
//
// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to
// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)")
// m.RenameColumn("from","to")...
//
// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately.
// //Supports standard column methods, automatic reverse.
// m.ForeignCol("local_col","foreign_col","foreign_table")
package migration

View File

@ -52,6 +52,26 @@ type Migrationer interface {
GetCreated() int64 GetCreated() int64
} }
//Migration defines the migrations by either SQL or DDL
type Migration struct {
sqls []string
Created string
TableName string
Engine string
Charset string
ModifyType string
Columns []*Column
Indexes []*Index
Primary []*Column
Uniques []*Unique
Foreigns []*Foreign
Renames []*RenameColumn
RemoveColumns []*Column
RemoveIndexes []*Index
RemoveUniques []*Unique
RemoveForeigns []*Foreign
}
var ( var (
migrationMap map[string]Migrationer migrationMap map[string]Migrationer
) )
@ -60,20 +80,34 @@ func init() {
migrationMap = make(map[string]Migrationer) migrationMap = make(map[string]Migrationer)
} }
// Migration the basic type which will implement the basic type
type Migration struct {
sqls []string
Created string
}
// Up implement in the Inheritance struct for upgrade // Up implement in the Inheritance struct for upgrade
func (m *Migration) Up() { func (m *Migration) Up() {
switch m.ModifyType {
case "reverse":
m.ModifyType = "alter"
case "delete":
m.ModifyType = "create"
}
m.sqls = append(m.sqls, m.GetSQL())
} }
// Down implement in the Inheritance struct for down // Down implement in the Inheritance struct for down
func (m *Migration) Down() { func (m *Migration) Down() {
switch m.ModifyType {
case "alter":
m.ModifyType = "reverse"
case "create":
m.ModifyType = "delete"
}
m.sqls = append(m.sqls, m.GetSQL())
}
//Migrate adds the SQL to the execution list
func (m *Migration) Migrate(migrationType string) {
m.ModifyType = migrationType
m.sqls = append(m.sqls, m.GetSQL())
} }
// SQL add sql want to execute // SQL add sql want to execute

View File

@ -267,13 +267,12 @@ func addPrefix(t *Tree, prefix string) {
addPrefix(t.wildcard, prefix) addPrefix(t.wildcard, prefix)
} }
for _, l := range t.leaves { for _, l := range t.leaves {
if c, ok := l.runObject.(*controllerInfo); ok { if c, ok := l.runObject.(*ControllerInfo); ok {
if !strings.HasPrefix(c.pattern, prefix) { if !strings.HasPrefix(c.pattern, prefix) {
c.pattern = prefix + c.pattern c.pattern = prefix + c.pattern
} }
} }
} }
} }
// NSCond is Namespace Condition // NSCond is Namespace Condition
@ -284,16 +283,16 @@ func NSCond(cond namespaceCond) LinkNamespace {
} }
// NSBefore Namespace BeforeRouter filter // NSBefore Namespace BeforeRouter filter
func NSBefore(filiterList ...FilterFunc) LinkNamespace { func NSBefore(filterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Filter("before", filiterList...) ns.Filter("before", filterList...)
} }
} }
// NSAfter add Namespace FinishRouter filter // NSAfter add Namespace FinishRouter filter
func NSAfter(filiterList ...FilterFunc) LinkNamespace { func NSAfter(filterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) { return func(ns *Namespace) {
ns.Filter("after", filiterList...) ns.Filter("after", filterList...)
} }
} }

View File

@ -139,10 +139,7 @@ func TestNamespaceCond(t *testing.T) {
ns := NewNamespace("/v2") ns := NewNamespace("/v2")
ns.Cond(func(ctx *context.Context) bool { ns.Cond(func(ctx *context.Context) bool {
if ctx.Input.Domain() == "beego.me" { return ctx.Input.Domain() == "beego.me"
return true
}
return false
}). }).
AutoRouter(&TestController{}) AutoRouter(&TestController{})
AddNamespace(ns) AddNamespace(ns)

View File

@ -150,7 +150,7 @@ func (d *commandSyncDb) Run() error {
} }
for _, fi := range mi.fields.fieldsDB { for _, fi := range mi.fields.fieldsDB {
if _, ok := columns[fi.column]; ok == false { if _, ok := columns[fi.column]; !ok {
fields = append(fields, fi) fields = append(fields, fi)
} }
} }
@ -175,7 +175,7 @@ func (d *commandSyncDb) Run() error {
} }
for _, idx := range indexes[mi.table] { for _, idx := range indexes[mi.table] {
if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false { if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) {
if !d.noInfo { if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
} }

View File

@ -89,7 +89,7 @@ checkColumn:
col = T["float64"] col = T["float64"]
case TypeDecimalField: case TypeDecimalField:
s := T["float64-decimal"] s := T["float64-decimal"]
if strings.Index(s, "%d") == -1 { if !strings.Contains(s, "%d") {
col = s col = s
} else { } else {
col = fmt.Sprintf(s, fi.digits, fi.decimals) col = fmt.Sprintf(s, fi.digits, fi.decimals)
@ -120,7 +120,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote() Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi) typ := getColumnTyp(al, fi)
if fi.null == false { if !fi.null {
typ += " " + "NOT NULL" typ += " " + "NOT NULL"
} }
@ -172,7 +172,7 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
} else { } else {
column += col column += col
if fi.null == false { if !fi.null {
column += " " + "NOT NULL" column += " " + "NOT NULL"
} }
@ -192,7 +192,7 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
} }
} }
if strings.Index(column, "%COL%") != -1 { if strings.Contains(column, "%COL%") {
column = strings.Replace(column, "%COL%", fi.column, -1) column = strings.Replace(column, "%COL%", fi.column, -1)
} }

View File

@ -48,7 +48,7 @@ var (
"lte": true, "lte": true,
"eq": true, "eq": true,
"nq": true, "nq": true,
"ne": true, "ne": true,
"startswith": true, "startswith": true,
"endswith": true, "endswith": true,
"istartswith": true, "istartswith": true,
@ -87,7 +87,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} else { } else {
panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName))
} }
if fi.dbcol == false || fi.auto && skipAuto { if !fi.dbcol || fi.auto && skipAuto {
continue continue
} }
value, err := d.collectFieldValue(mi, fi, ind, insert, tz) value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
@ -224,7 +224,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
value = nil value = nil
} }
} }
if fi.null == false && value == nil { if !fi.null && value == nil {
return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
} }
} }
@ -271,7 +271,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
dbcols := make([]string, 0, len(mi.fields.dbcols)) dbcols := make([]string, 0, len(mi.fields.dbcols))
marks := make([]string, 0, len(mi.fields.dbcols)) marks := make([]string, 0, len(mi.fields.dbcols))
for _, fi := range mi.fields.fieldsDB { for _, fi := range mi.fields.fieldsDB {
if fi.auto == false { if !fi.auto {
dbcols = append(dbcols, fi.column) dbcols = append(dbcols, fi.column)
marks = append(marks, "?") marks = append(marks, "?")
} }
@ -326,7 +326,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
} else { } else {
// default use pk value as where condtion. // default use pk value as where condtion.
pkColumn, pkValue, ok := getExistPk(mi, ind) pkColumn, pkValue, ok := getExistPk(mi, ind)
if ok == false { if !ok {
return ErrMissPK return ErrMissPK
} }
whereCols = []string{pkColumn} whereCols = []string{pkColumn}
@ -507,10 +507,9 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
case DRPostgres: case DRPostgres:
if len(args) == 0 { if len(args) == 0 {
return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
} else {
args0 = strings.ToLower(args[0])
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
} }
args0 = strings.ToLower(args[0])
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
default: default:
return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
} }
@ -592,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
row := q.QueryRow(query, values...) row := q.QueryRow(query, values...)
var id int64 var id int64
err = row.Scan(&id) err = row.Scan(&id)
if err.Error() == `pq: syntax error at or near "ON"` { if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
err = fmt.Errorf("postgres version must 9.5 or higher") err = fmt.Errorf("postgres version must 9.5 or higher")
} }
return id, err return id, err
@ -601,7 +600,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
// execute update sql dbQuerier with given struct reflect.Value. // execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind) pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false { if !ok {
return 0, ErrMissPK return 0, ErrMissPK
} }
@ -654,7 +653,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
} else { } else {
// default use pk value as where condtion. // default use pk value as where condtion.
pkColumn, pkValue, ok := getExistPk(mi, ind) pkColumn, pkValue, ok := getExistPk(mi, ind)
if ok == false { if !ok {
return 0, ErrMissPK return 0, ErrMissPK
} }
whereCols = []string{pkColumn} whereCols = []string{pkColumn}
@ -699,7 +698,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
columns := make([]string, 0, len(params)) columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params)) values := make([]interface{}, 0, len(params))
for col, val := range params { for col, val := range params {
if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false { if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol {
panic(fmt.Errorf("wrong field/column name `%s`", col)) panic(fmt.Errorf("wrong field/column name `%s`", col))
} else { } else {
columns = append(columns, fi.column) columns = append(columns, fi.column)
@ -834,7 +833,11 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
if err := rs.Scan(&ref); err != nil { if err := rs.Scan(&ref); err != nil {
return 0, err return 0, err
} }
args = append(args, reflect.ValueOf(ref).Interface()) pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz)
if err != nil {
return 0, err
}
args = append(args, pkValue)
cnt++ cnt++
} }
@ -929,7 +932,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
if hasRel { if hasRel {
for _, fi := range mi.fields.fieldsDB { for _, fi := range mi.fields.fieldsDB {
if fi.fieldType&IsRelField > 0 { if fi.fieldType&IsRelField > 0 {
if maps[fi.column] == false { if !maps[fi.column] {
tCols = append(tCols, fi.column) tCols = append(tCols, fi.column)
} }
} }
@ -987,7 +990,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
var cnt int64 var cnt int64
for rs.Next() { for rs.Next() {
if one && cnt == 0 || one == false { if one && cnt == 0 || !one {
if err := rs.Scan(refs...); err != nil { if err := rs.Scan(refs...); err != nil {
return 0, err return 0, err
} }
@ -1067,7 +1070,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
cnt++ cnt++
} }
if one == false { if !one {
if cnt > 0 { if cnt > 0 {
ind.Set(slice) ind.Set(slice)
} else { } else {
@ -1110,7 +1113,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
// generate sql with replacing operator string placeholders and replaced values. // generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := "" var sql string
params := getFlatParams(fi, args, tz) params := getFlatParams(fi, args, tz)
if len(params) == 0 { if len(params) == 0 {
@ -1357,7 +1360,7 @@ end:
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
fieldType := fi.fieldType fieldType := fi.fieldType
isNative := fi.isFielder == false isNative := !fi.isFielder
setValue: setValue:
switch { switch {
@ -1533,7 +1536,7 @@ setValue:
} }
} }
if isNative == false { if !isNative {
fd := field.Addr().Interface().(Fielder) fd := field.Addr().Interface().(Fielder)
err := fd.SetRaw(value) err := fd.SetRaw(value)
if err != nil { if err != nil {
@ -1594,7 +1597,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
infos = make([]*fieldInfo, 0, len(exprs)) infos = make([]*fieldInfo, 0, len(exprs))
for _, ex := range exprs { for _, ex := range exprs {
index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
if suc == false { if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", ex)) panic(fmt.Errorf("unknown field/column name `%s`", ex))
} }
cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q))
@ -1733,7 +1736,7 @@ func (d *dbBase) TableQuote() string {
return "`" return "`"
} }
// replace value placeholer in parametered sql string. // replace value placeholder in parametered sql string.
func (d *dbBase) ReplaceMarks(query *string) { func (d *dbBase) ReplaceMarks(query *string) {
// default use `?` as mark, do nothing // default use `?` as mark, do nothing
} }

View File

@ -60,6 +60,8 @@ var (
"sqlite3": DRSqlite, "sqlite3": DRSqlite,
"tidb": DRTiDB, "tidb": DRTiDB,
"oracle": DROracle, "oracle": DROracle,
"oci8": DROracle, // github.com/mattn/go-oci8
"ora": DROracle, //https://github.com/rana/ora
} }
dbBasers = map[DriverType]dbBaser{ dbBasers = map[DriverType]dbBaser{
DRMySQL: newdbBaseMysql(), DRMySQL: newdbBaseMysql(),
@ -186,7 +188,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
} }
if dataBaseCache.add(aliasName, al) == false { if !dataBaseCache.add(aliasName, al) {
return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
} }
@ -244,11 +246,11 @@ end:
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error { func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; ok == false { if t, ok := drivers[driverName]; !ok {
drivers[driverName] = typ drivers[driverName] = typ
} else { } else {
if t != typ { if t != typ {
return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName) return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
} }
} }
return nil return nil
@ -259,7 +261,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok { if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz al.TZ = tz
} else { } else {
return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName) return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
} }
return nil return nil
} }
@ -294,5 +296,5 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
if ok { if ok {
return al.DB, nil return al.DB, nil
} }
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name) return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
} }

View File

@ -103,8 +103,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
// If no will insert // If no will insert
// Add "`" for mysql sql building // Add "`" for mysql sql building
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
var iouStr string
iouStr := ""
argsMap := map[string]string{} argsMap := map[string]string{}
iouStr = "ON DUPLICATE KEY UPDATE" iouStr = "ON DUPLICATE KEY UPDATE"

View File

@ -94,3 +94,43 @@ func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool
row.Scan(&cnt) row.Scan(&cnt)
return cnt > 0 return cnt > 0
} }
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
for i := range marks {
marks[i] = ":" + names[i]
}
sep := fmt.Sprintf("%s, %s", Q, Q)
qmarks := strings.Join(marks, ", ")
columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.Exec(query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId()
}
return 0, err
}
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
}

View File

@ -134,7 +134,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var tmp, index sql.NullString var tmp, index sql.NullString
rows.Scan(&tmp, &index, &tmp) rows.Scan(&tmp, &index, &tmp, &tmp, &tmp)
if name == index.String { if name == index.String {
return true return true
} }

View File

@ -63,7 +63,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
// add table info to collection. // add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep) name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; ok == false { if _, ok := t.tablesM[name]; !ok {
i := len(t.tables) + 1 i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt t.tablesM[name] = jt
@ -261,7 +261,7 @@ loopFor:
fiN, okN = mmi.fields.GetByAny(exprs[i+1]) fiN, okN = mmi.fields.GetByAny(exprs[i+1])
} }
if isRel && (fi.mi.isThrough == false || num != i) { if isRel && (!fi.mi.isThrough || num != i) {
if fi.null || t.skipEnd { if fi.null || t.skipEnd {
inner = false inner = false
} }
@ -364,7 +364,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe
} }
index, _, fi, suc := t.parseExprs(mi, exprs) index, _, fi, suc := t.parseExprs(mi, exprs)
if suc == false { if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
} }
@ -383,7 +383,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe
} }
} }
if sub == false && where != "" { if !sub && where != "" {
where = "WHERE " + where where = "WHERE " + where
} }
@ -403,7 +403,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
exprs := strings.Split(group, ExprSep) exprs := strings.Split(group, ExprSep)
index, _, fi, suc := t.parseExprs(t.mi, exprs) index, _, fi, suc := t.parseExprs(t.mi, exprs)
if suc == false { if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
} }
@ -432,7 +432,7 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
exprs := strings.Split(order, ExprSep) exprs := strings.Split(order, ExprSep)
index, _, fi, suc := t.parseExprs(t.mi, exprs) index, _, fi, suc := t.parseExprs(t.mi, exprs)
if suc == false { if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
} }

View File

@ -41,6 +41,8 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
vu := v.Int() vu := v.Int()
exist = true exist = true
value = vu value = vu
} else if fi.fieldType&IsRelField > 0 {
_, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v))
} else { } else {
vu := v.String() vu := v.String()
exist = vu != "" exist = vu != ""

View File

@ -75,7 +75,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
} }
if mi.fields.pk == nil { if mi.fields.pk == nil {
fmt.Printf("<orm.RegisterModel> `%s` need a primary key field, default use 'id' if not set\n", name) fmt.Printf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
os.Exit(2) os.Exit(2)
} }
@ -117,7 +117,7 @@ func bootStrap() {
name := getFullName(elm) name := getFullName(elm)
mii, ok := modelCache.getByFullName(name) mii, ok := modelCache.getByFullName(name)
if !ok || mii.pkg != elm.PkgPath() { if !ok || mii.pkg != elm.PkgPath() {
err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
goto end goto end
} }
fi.relModelInfo = mii fi.relModelInfo = mii
@ -128,7 +128,7 @@ func bootStrap() {
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
pn := fi.relThrough[:i] pn := fi.relThrough[:i]
rmi, ok := modelCache.getByFullName(fi.relThrough) rmi, ok := modelCache.getByFullName(fi.relThrough)
if ok == false || pn != rmi.pkg { if !ok || pn != rmi.pkg {
err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
goto end goto end
} }
@ -171,7 +171,7 @@ func bootStrap() {
break break
} }
} }
if inModel == false { if !inModel {
rmi := fi.relModelInfo rmi := fi.relModelInfo
ffi := new(fieldInfo) ffi := new(fieldInfo)
ffi.name = mi.name ffi.name = mi.name
@ -185,7 +185,7 @@ func bootStrap() {
} else { } else {
ffi.fieldType = RelReverseMany ffi.fieldType = RelReverseMany
} }
if rmi.fields.Add(ffi) == false { if !rmi.fields.Add(ffi) {
added := false added := false
for cnt := 0; cnt < 5; cnt++ { for cnt := 0; cnt < 5; cnt++ {
ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
@ -195,7 +195,7 @@ func bootStrap() {
break break
} }
} }
if added == false { if !added {
panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
} }
} }
@ -248,7 +248,7 @@ func bootStrap() {
break mForA break mForA
} }
} }
if found == false { if !found {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end goto end
} }
@ -267,7 +267,7 @@ func bootStrap() {
break mForB break mForB
} }
} }
if found == false { if !found {
mForC: mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
@ -287,7 +287,7 @@ func bootStrap() {
} }
} }
} }
if found == false { if !found {
err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end goto end
} }

View File

@ -47,7 +47,7 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
} else { } else {
return return
} }
if _, ok := f.fieldsByType[fi.fieldType]; ok == false { if _, ok := f.fieldsByType[fi.fieldType]; !ok {
f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0)
} }
f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi)
@ -334,12 +334,12 @@ checkType:
switch onDelete { switch onDelete {
case odCascade, odDoNothing: case odCascade, odDoNothing:
case odSetDefault: case odSetDefault:
if initial.Exist() == false { if !initial.Exist() {
err = errors.New("on_delete: set_default need set field a default value") err = errors.New("on_delete: set_default need set field a default value")
goto end goto end
} }
case odSetNULL: case odSetNULL:
if fi.null == false { if !fi.null {
err = errors.New("on_delete: set_null need set field null") err = errors.New("on_delete: set_null need set field null")
goto end goto end
} }

View File

@ -78,7 +78,7 @@ func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int)
fi.fieldIndex = append(index, i) fi.fieldIndex = append(index, i)
fi.mi = mi fi.mi = mi
fi.inModel = true fi.inModel = true
if mi.fields.Add(fi) == false { if !mi.fields.Add(fi) {
err = fmt.Errorf("duplicate column name: %s", fi.column) err = fmt.Errorf("duplicate column name: %s", fi.column)
break break
} }

View File

@ -406,6 +406,11 @@ type UintPk struct {
Name string Name string
} }
type PtrPk struct {
ID *IntegerPk `orm:"pk;rel(one)"`
Positive bool
}
var DBARGS = struct { var DBARGS = struct {
Driver string Driver string
Source string Source string

View File

@ -107,7 +107,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect
if mi, ok := modelCache.getByFullName(name); ok { if mi, ok := modelCache.getByFullName(name); ok {
return mi, ind return mi, ind
} }
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name)) panic(fmt.Errorf("<Ormer> table: `%s` not found, make sure it was registered with `RegisterModel()`", name))
} }
// get field info from model info by given field name // get field info from model info by given field name
@ -122,21 +122,13 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
// read data to model // read data to model
func (o *orm) Read(md interface{}, cols ...string) error { func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
if err != nil {
return err
}
return nil
} }
// read data to model, like Read(), but use "SELECT FOR UPDATE" form // read data to model, like Read(), but use "SELECT FOR UPDATE" form
func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { func (o *orm) ReadForUpdate(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
if err != nil {
return err
}
return nil
} }
// Try to read a row from the database, or insert one if it doesn't exist // Try to read a row from the database, or insert one if it doesn't exist
@ -153,6 +145,8 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i
id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
id = int64(vid.Uint()) id = int64(vid.Uint())
} else if mi.fields.pk.rel {
return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name)
} else { } else {
id = vid.Int() id = vid.Int()
} }
@ -236,15 +230,11 @@ func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64
// cols set the columns those want to update. // cols set the columns those want to update.
func (o *orm) Update(md interface{}, cols ...string) (int64, error) { func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return num, err
}
return num, nil
} }
// delete model in database // delete model in database
// cols shows the delete conditions values read from. deafult is pk // cols shows the delete conditions values read from. default is pk
func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { func (o *orm) Delete(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true) mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
@ -359,7 +349,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
fi := o.getFieldInfo(mi, name) fi := o.getFieldInfo(mi, name)
_, _, exist := getExistPk(mi, ind) _, _, exist := getExistPk(mi, ind)
if exist == false { if !exist {
panic(ErrMissPK) panic(ErrMissPK)
} }
@ -430,7 +420,7 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
// table name can be string or struct. // table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
name := "" var name string
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
name = snakeString(table) name = snakeString(table)
if mi, ok := modelCache.get(name); ok { if mi, ok := modelCache.get(name); ok {
@ -487,7 +477,7 @@ func (o *orm) Begin() error {
// commit transaction // commit transaction
func (o *orm) Commit() error { func (o *orm) Commit() error {
if o.isTx == false { if !o.isTx {
return ErrTxDone return ErrTxDone
} }
err := o.db.(txEnder).Commit() err := o.db.(txEnder).Commit()
@ -502,7 +492,7 @@ func (o *orm) Commit() error {
// rollback transaction // rollback transaction
func (o *orm) Rollback() error { func (o *orm) Rollback() error {
if o.isTx == false { if !o.isTx {
return ErrTxDone return ErrTxDone
} }
err := o.db.(txEnder).Rollback() err := o.db.(txEnder).Rollback()

View File

@ -72,7 +72,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
} }
_, v1, exist := getExistPk(o.mi, o.ind) _, v1, exist := getExistPk(o.mi, o.ind)
if exist == false { if !exist {
panic(ErrMissPK) panic(ErrMissPK)
} }
@ -87,7 +87,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
v2 = ind.Interface() v2 = ind.Interface()
} else { } else {
_, v2, exist = getExistPk(fi.relModelInfo, ind) _, v2, exist = getExistPk(fi.relModelInfo, ind)
if exist == false { if !exist {
panic(ErrMissPK) panic(ErrMissPK)
} }
} }
@ -104,11 +104,7 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
fi := o.fi fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
nums, err := qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete()
if err != nil {
return nums, err
}
return nums, nil
} }
// check model is existed in relationship of origin model // check model is existed in relationship of origin model

View File

@ -153,6 +153,11 @@ func (o querySet) SetCond(cond *Condition) QuerySeter {
return &o return &o
} }
// get condition from QuerySeter
func (o querySet) GetCond() *Condition {
return o.cond
}
// return QuerySeter execution result number // return QuerySeter execution result number
func (o *querySet) Count() (int64, error) { func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)

View File

@ -493,19 +493,33 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
} }
} }
} else { } else {
for i := 0; i < ind.NumField(); i++ { // define recursive function
f := ind.Field(i) var recursiveSetField func(rv reflect.Value)
fe := ind.Type().Field(i) recursiveSetField = func(rv reflect.Value) {
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) for i := 0; i < rv.NumField(); i++ {
var col string f := rv.Field(i)
if col = tags["column"]; col == "" { fe := rv.Type().Field(i)
col = snakeString(fe.Name)
} // check if the field is a Struct
if v, ok := columnsMp[col]; ok { // recursive the Struct type
value := reflect.ValueOf(v).Elem().Interface() if fe.Type.Kind() == reflect.Struct {
o.setFieldValue(f, value) recursiveSetField(f)
}
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string
if col = tags["column"]; col == "" {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
} }
} }
// init call the recursive function
recursiveSetField(ind)
} }
if eTyps[0].Kind() == reflect.Ptr { if eTyps[0].Kind() == reflect.Ptr {
@ -671,7 +685,7 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
ind *reflect.Value ind *reflect.Value
) )
typ := 0 var typ int
switch container.(type) { switch container.(type) {
case *Params: case *Params:
typ = 1 typ = 1

View File

@ -93,14 +93,14 @@ wrongArg:
} }
func AssertIs(a interface{}, args ...interface{}) error { func AssertIs(a interface{}, args ...interface{}) error {
if ok, err := ValuesCompare(true, a, args...); ok == false { if ok, err := ValuesCompare(true, a, args...); !ok {
return err return err
} }
return nil return nil
} }
func AssertNot(a interface{}, args ...interface{}) error { func AssertNot(a interface{}, args ...interface{}) error {
if ok, err := ValuesCompare(false, a, args...); ok == false { if ok, err := ValuesCompare(false, a, args...); !ok {
return err return err
} }
return nil return nil
@ -135,7 +135,7 @@ func getCaller(skip int) string {
if i := strings.LastIndex(funName, "."); i > -1 { if i := strings.LastIndex(funName, "."); i > -1 {
funName = funName[i+1:] funName = funName[i+1:]
} }
return fmt.Sprintf("%s:%d: \n%s", fn, line, strings.Join(codes, "\n")) return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n"))
} }
func throwFail(t *testing.T, err error, args ...interface{}) { func throwFail(t *testing.T, err error, args ...interface{}) {
@ -193,6 +193,7 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(InLineOneToOne)) RegisterModel(new(InLineOneToOne))
RegisterModel(new(IntegerPk)) RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk)) RegisterModel(new(UintPk))
RegisterModel(new(PtrPk))
err := RunSyncdb("default", true, Debug) err := RunSyncdb("default", true, Debug)
throwFail(t, err) throwFail(t, err)
@ -216,6 +217,7 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(InLineOneToOne)) RegisterModel(new(InLineOneToOne))
RegisterModel(new(IntegerPk)) RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk)) RegisterModel(new(UintPk))
RegisterModel(new(PtrPk))
BootStrap() BootStrap()
@ -1012,6 +1014,8 @@ func TestAll(t *testing.T) {
var users3 []*User var users3 []*User
qs = dORM.QueryTable("user") qs = dORM.QueryTable("user")
num, err = qs.Filter("user_name", "nothing").All(&users3) num, err = qs.Filter("user_name", "nothing").All(&users3)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 0))
throwFailNow(t, AssertIs(users3 == nil, false)) throwFailNow(t, AssertIs(users3 == nil, false))
} }
@ -1136,6 +1140,7 @@ func TestRelatedSel(t *testing.T) {
} }
err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user) err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user)
throwFail(t, err)
throwFail(t, AssertIs(num, 1)) throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(user.Profile, nil)) throwFail(t, AssertIs(user.Profile, nil))
@ -1244,20 +1249,24 @@ func TestLoadRelated(t *testing.T) {
num, err = dORM.LoadRelated(&user, "Posts", true) num, err = dORM.LoadRelated(&user, "Posts", true)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie"))
num, err = dORM.LoadRelated(&user, "Posts", true, 1) num, err = dORM.LoadRelated(&user, "Posts", true, 1)
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(user.Posts), 1)) throwFailNow(t, AssertIs(len(user.Posts), 1))
num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id")
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id")
throwFailNow(t, err) throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(user.Posts), 1)) throwFailNow(t, AssertIs(len(user.Posts), 1))
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
@ -1652,6 +1661,13 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(pid, nil)) throwFail(t, AssertIs(pid, nil))
} }
// user_profile table
type userProfile struct {
User
Age int
Money float64
}
func TestQueryRows(t *testing.T) { func TestQueryRows(t *testing.T) {
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
@ -1722,6 +1738,19 @@ func TestQueryRows(t *testing.T) {
throwFailNow(t, AssertIs(usernames[1], "astaxie")) throwFailNow(t, AssertIs(usernames[1], "astaxie"))
throwFailNow(t, AssertIs(ids[2], 4)) throwFailNow(t, AssertIs(ids[2], 4))
throwFailNow(t, AssertIs(usernames[2], "nobody")) throwFailNow(t, AssertIs(usernames[2], "nobody"))
//test query rows by nested struct
var l []userProfile
query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q)
num, err = dORM.Raw(query).QueryRows(&l)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(l), 2))
throwFailNow(t, AssertIs(l[0].UserName, "slene"))
throwFailNow(t, AssertIs(l[0].Age, 28))
throwFailNow(t, AssertIs(l[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(l[1].Age, 30))
} }
func TestRawValues(t *testing.T) { func TestRawValues(t *testing.T) {
@ -1974,6 +2003,7 @@ func TestReadOrCreate(t *testing.T) {
created, pk, err := dORM.ReadOrCreate(u, "UserName") created, pk, err := dORM.ReadOrCreate(u, "UserName")
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(created, true)) throwFail(t, AssertIs(created, true))
throwFail(t, AssertIs(u.ID, pk))
throwFail(t, AssertIs(u.UserName, "Kyle")) throwFail(t, AssertIs(u.UserName, "Kyle"))
throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com")) throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com"))
throwFail(t, AssertIs(u.Password, "other_pass")) throwFail(t, AssertIs(u.Password, "other_pass"))
@ -2128,13 +2158,13 @@ func TestUintPk(t *testing.T) {
Name: name, Name: name,
} }
created, pk, err := dORM.ReadOrCreate(u, "ID") created, _, err := dORM.ReadOrCreate(u, "ID")
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(created, true)) throwFail(t, AssertIs(created, true))
throwFail(t, AssertIs(u.Name, name)) throwFail(t, AssertIs(u.Name, name))
nu := &UintPk{ID: 8} nu := &UintPk{ID: 8}
created, pk, err = dORM.ReadOrCreate(nu, "ID") created, pk, err := dORM.ReadOrCreate(nu, "ID")
throwFail(t, err) throwFail(t, err)
throwFail(t, AssertIs(created, false)) throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(nu.ID, u.ID)) throwFail(t, AssertIs(nu.ID, u.ID))
@ -2144,6 +2174,48 @@ func TestUintPk(t *testing.T) {
dORM.Delete(u) dORM.Delete(u)
} }
func TestPtrPk(t *testing.T) {
parent := &IntegerPk{ID: 10, Value: "10"}
id, _ := dORM.Insert(parent)
if !IsMysql {
// MySql does not support last_insert_id in this case: see #2382
throwFail(t, AssertIs(id, 10))
}
ptr := PtrPk{ID: parent, Positive: true}
num, err := dORM.InsertMulti(2, []PtrPk{ptr})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(ptr.ID, parent))
nptr := &PtrPk{ID: parent}
created, pk, err := dORM.ReadOrCreate(nptr, "ID")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(pk, 10))
throwFail(t, AssertIs(nptr.ID, parent))
throwFail(t, AssertIs(nptr.Positive, true))
nptr = &PtrPk{Positive: true}
created, pk, err = dORM.ReadOrCreate(nptr, "Positive")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(pk, 10))
throwFail(t, AssertIs(nptr.ID, parent))
nptr.Positive = false
num, err = dORM.Update(nptr)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(nptr.ID, parent))
throwFail(t, AssertIs(nptr.Positive, false))
num, err = dORM.Delete(nptr)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
func TestSnake(t *testing.T) { func TestSnake(t *testing.T) {
cases := map[string]string{ cases := map[string]string{
"i": "i", "i": "i",

View File

@ -145,6 +145,16 @@ type QuerySeter interface {
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond1).Count() // num, err := qs.SetCond(cond1).Count()
SetCond(*Condition) QuerySeter SetCond(*Condition) QuerySeter
// get condition from QuerySeter.
// sql's where condition
// cond := orm.NewCondition()
// cond = cond.And("profile__isnull", false).AndNot("status__in", 1)
// qs = qs.SetCond(cond)
// cond = qs.GetCond()
// cond := cond.Or("profile__age__gt", 2000)
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond).Count()
GetCond() *Condition
// add LIMIT value. // add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset. // args[0] means offset, e.g. LIMIT num,offset.
// if Limit <= 0 then Limit will be set to default limit ,eg 1000 // if Limit <= 0 then Limit will be set to default limit ,eg 1000

View File

@ -92,11 +92,11 @@ func (f StrTo) Int64() (int64, error) {
i := new(big.Int) i := new(big.Int)
ni, ok := i.SetString(f.String(), 10) // octal ni, ok := i.SetString(f.String(), 10) // octal
if !ok { if !ok {
return int64(v), err return v, err
} }
return ni.Int64(), nil return ni.Int64(), nil
} }
return int64(v), err return v, err
} }
// Uint string to uint // Uint string to uint
@ -130,11 +130,11 @@ func (f StrTo) Uint64() (uint64, error) {
i := new(big.Int) i := new(big.Int)
ni, ok := i.SetString(f.String(), 10) ni, ok := i.SetString(f.String(), 10)
if !ok { if !ok {
return uint64(v), err return v, err
} }
return ni.Uint64(), nil return ni.Uint64(), nil
} }
return uint64(v), err return v, err
} }
// String string to string // String string to string
@ -219,22 +219,17 @@ func snakeString(s string) string {
// camel string, xx_yy to XxYy // camel string, xx_yy to XxYy
func camelString(s string) string { func camelString(s string) string {
data := make([]byte, 0, len(s)) data := make([]byte, 0, len(s))
j := false flag, num := true, len(s)-1
k := false
num := len(s) - 1
for i := 0; i <= num; i++ { for i := 0; i <= num; i++ {
d := s[i] d := s[i]
if k == false && d >= 'A' && d <= 'Z' { if d == '_' {
k = true flag = true
}
if d >= 'a' && d <= 'z' && (j || k == false) {
d = d - 32
j = false
k = true
}
if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' {
j = true
continue continue
} else if flag {
if d >= 'a' && d <= 'z' {
d = d - 32
}
flag = false
} }
data = append(data, d) data = append(data, d)
} }

36
orm/utils_test.go Normal file
View File

@ -0,0 +1,36 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"testing"
)
func TestCamelString(t *testing.T) {
snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"}
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"}
answer := make(map[string]string)
for i, v := range snake {
answer[v] = camel[i]
}
for _, v := range snake {
res := camelString(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
}

212
parser.go
View File

@ -24,9 +24,13 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"sort" "sort"
"strconv"
"strings" "strings"
"unicode"
"github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
@ -35,6 +39,7 @@ var globalRouterTemplate = `package routers
import ( import (
"github.com/astaxie/beego" "github.com/astaxie/beego"
"github.com/astaxie/beego/context/param"
) )
func init() { func init() {
@ -81,7 +86,7 @@ func parserPkg(pkgRealpath, pkgpath string) error {
if specDecl.Recv != nil { if specDecl.Recv != nil {
exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser
if ok { if ok {
parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath) parserComments(specDecl, fmt.Sprint(exp.X), pkgpath)
} }
} }
} }
@ -93,44 +98,170 @@ func parserPkg(pkgRealpath, pkgpath string) error {
return nil return nil
} }
func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpath string) error { type parsedComment struct {
if comments != nil && comments.List != nil { routerPath string
for _, c := range comments.List { methods []string
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) params map[string]parsedParam
if strings.HasPrefix(t, "@router") { }
elements := strings.TrimLeft(t, "@router ")
e1 := strings.SplitN(elements, " ", 2) type parsedParam struct {
if len(e1) < 1 { name string
return errors.New("you should has router information") datatype string
} location string
key := pkgpath + ":" + controllerName defValue string
cc := ControllerComments{} required bool
cc.Method = funcName }
cc.Router = e1[0]
if len(e1) == 2 && e1[1] != "" { func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
e1 = strings.SplitN(e1[1], " ", 2) if f.Doc != nil {
if len(e1) >= 1 { parsedComment, err := parseComment(f.Doc.List)
cc.AllowHTTPMethods = strings.Split(strings.Trim(e1[0], "[]"), ",") if err != nil {
} else { return err
cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get")
}
} else {
cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get")
}
if len(e1) == 2 && e1[1] != "" {
keyval := strings.Split(strings.Trim(e1[1], "[]"), " ")
for _, kv := range keyval {
kk := strings.Split(kv, ":")
cc.Params = append(cc.Params, map[string]string{strings.Join(kk[:len(kk)-1], ":"): kk[len(kk)-1]})
}
}
genInfoList[key] = append(genInfoList[key], cc)
}
} }
if parsedComment.routerPath != "" {
key := pkgpath + ":" + controllerName
cc := ControllerComments{}
cc.Method = f.Name.String()
cc.Router = parsedComment.routerPath
cc.AllowHTTPMethods = parsedComment.methods
cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
genInfoList[key] = append(genInfoList[key], cc)
}
} }
return nil return nil
} }
func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
result := make([]*param.MethodParam, 0, len(funcParams))
for _, fparam := range funcParams {
for _, pName := range fparam.Names {
methodParam := buildMethodParam(fparam, pName.Name, pc)
result = append(result, methodParam)
}
}
return result
}
func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam {
options := []param.MethodParamOption{}
if cparam, ok := pc.params[name]; ok {
//Build param from comment info
name = cparam.name
if cparam.required {
options = append(options, param.IsRequired)
}
switch cparam.location {
case "body":
options = append(options, param.InBody)
case "header":
options = append(options, param.InHeader)
case "path":
options = append(options, param.InPath)
}
if cparam.defValue != "" {
options = append(options, param.Default(cparam.defValue))
}
} else {
if paramInPath(name, pc.routerPath) {
options = append(options, param.InPath)
}
}
return param.New(name, options...)
}
func paramInPath(name, route string) bool {
return strings.HasSuffix(route, ":"+name) ||
strings.Contains(route, ":"+name+"/")
}
var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) {
pc = &parsedComment{}
for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@router") {
matches := routeRegex.FindStringSubmatch(t)
if len(matches) == 3 {
pc.routerPath = matches[1]
methods := matches[2]
if methods == "" {
pc.methods = []string{"get"}
//pc.hasGet = true
} else {
pc.methods = strings.Split(methods, ",")
//pc.hasGet = strings.Contains(methods, "get")
}
} else {
return nil, errors.New("Router information is missing")
}
} else if strings.HasPrefix(t, "@Param") {
pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param")))
if len(pv) < 4 {
logs.Error("Invalid @Param format. Needs at least 4 parameters")
}
p := parsedParam{}
names := strings.SplitN(pv[0], "=>", 2)
p.name = names[0]
funcParamName := p.name
if len(names) > 1 {
funcParamName = names[1]
}
p.location = pv[1]
p.datatype = pv[2]
switch len(pv) {
case 5:
p.required, _ = strconv.ParseBool(pv[3])
case 6:
p.defValue = pv[3]
p.required, _ = strconv.ParseBool(pv[4])
}
if pc.params == nil {
pc.params = map[string]parsedParam{}
}
pc.params[funcParamName] = p
}
}
return
}
// direct copy from bee\g_docs.go
// analysis params return []string
// @Param query form string true "The email for login"
// [query form string true "The email for login"]
func getparams(str string) []string {
var s []rune
var j int
var start bool
var r []string
var quoted int8
for _, c := range str {
if unicode.IsSpace(c) && quoted == 0 {
if !start {
continue
} else {
start = false
j++
r = append(r, string(s))
s = make([]rune, 0)
continue
}
}
start = true
if c == '"' {
quoted ^= 1
continue
}
s = append(s, c)
}
if len(s) > 0 {
r = append(r, string(s))
}
return r
}
func genRouterCode(pkgRealpath string) { func genRouterCode(pkgRealpath string) {
os.Mkdir(getRouterDir(pkgRealpath), 0755) os.Mkdir(getRouterDir(pkgRealpath), 0755)
logs.Info("generate router from comments") logs.Info("generate router from comments")
@ -144,6 +275,7 @@ func genRouterCode(pkgRealpath string) {
sort.Strings(sortKey) sort.Strings(sortKey)
for _, k := range sortKey { for _, k := range sortKey {
cList := genInfoList[k] cList := genInfoList[k]
sort.Sort(ControllerCommentsSlice(cList))
for _, c := range cList { for _, c := range cList {
allmethod := "nil" allmethod := "nil"
if len(c.AllowHTTPMethods) > 0 { if len(c.AllowHTTPMethods) > 0 {
@ -163,12 +295,24 @@ func genRouterCode(pkgRealpath string) {
} }
params = strings.TrimRight(params, ",") + "}" params = strings.TrimRight(params, ",") + "}"
} }
methodParams := "param.Make("
if len(c.MethodParams) > 0 {
lines := make([]string, 0, len(c.MethodParams))
for _, m := range c.MethodParams {
lines = append(lines, fmt.Sprint(m))
}
methodParams += "\n " +
strings.Join(lines, ",\n ") +
",\n "
}
methodParams += ")"
globalinfo = globalinfo + ` globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{ beego.ControllerComments{
Method: "` + strings.TrimSpace(c.Method) + `", Method: "` + strings.TrimSpace(c.Method) + `",
` + "Router: `" + c.Router + "`" + `, ` + "Router: `" + c.Router + "`" + `,
AllowHTTPMethods: ` + allmethod + `, AllowHTTPMethods: ` + allmethod + `,
MethodParams: ` + methodParams + `,
Params: ` + params + `}) Params: ` + params + `})
` `
} }

View File

@ -56,6 +56,7 @@
package apiauth package apiauth
import ( import (
"bytes"
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
@ -128,53 +129,32 @@ func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc {
// Signature used to generate signature with the appsecret/method/params/RequestURI // Signature used to generate signature with the appsecret/method/params/RequestURI
func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) { func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) {
var query string var b bytes.Buffer
keys := make([]string, len(params))
pa := make(map[string]string) pa := make(map[string]string)
for k, v := range params { for k, v := range params {
pa[k] = v[0] pa[k] = v[0]
keys = append(keys, k)
} }
vs := mapSorter(pa)
vs.Sort() sort.Strings(keys)
for i := 0; i < vs.Len(); i++ {
if vs.Keys[i] == "signature" { for _, key := range keys {
if key == "signature" {
continue continue
} }
if vs.Keys[i] != "" && vs.Vals[i] != "" {
query = fmt.Sprintf("%v%v%v", query, vs.Keys[i], vs.Vals[i]) val := pa[key]
if key != "" && val != "" {
b.WriteString(key)
b.WriteString(val)
} }
} }
stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURL)
stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL)
sha256 := sha256.New sha256 := sha256.New
hash := hmac.New(sha256, []byte(appsecret)) hash := hmac.New(sha256, []byte(appsecret))
hash.Write([]byte(stringToSign)) hash.Write([]byte(stringToSign))
return base64.StdEncoding.EncodeToString(hash.Sum(nil)) return base64.StdEncoding.EncodeToString(hash.Sum(nil))
} }
type valSorter struct {
Keys []string
Vals []string
}
func mapSorter(m map[string]string) *valSorter {
vs := &valSorter{
Keys: make([]string, 0, len(m)),
Vals: make([]string, 0, len(m)),
}
for k, v := range m {
vs.Keys = append(vs.Keys, k)
vs.Vals = append(vs.Vals, v)
}
return vs
}
func (vs *valSorter) Sort() {
sort.Sort(vs)
}
func (vs *valSorter) Len() int { return len(vs.Keys) }
func (vs *valSorter) Less(i, j int) bool { return vs.Keys[i] < vs.Keys[j] }
func (vs *valSorter) Swap(i, j int) {
vs.Vals[i], vs.Vals[j] = vs.Vals[j], vs.Vals[i]
vs.Keys[i], vs.Keys[j] = vs.Keys[j], vs.Keys[i]
}

View File

@ -0,0 +1,20 @@
package apiauth
import (
"net/url"
"testing"
)
func TestSignature(t *testing.T) {
appsecret := "beego secret"
method := "GET"
RequestURL := "http://localhost/test/url"
params := make(url.Values)
params.Add("arg1", "hello")
params.Add("arg2", "beego")
signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58="
if Signature(appsecret, method, params, RequestURL) != signature {
t.Error("Signature error")
}
}

86
plugins/authz/authz.go Normal file
View File

@ -0,0 +1,86 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support.
// Simple Usage:
// import(
// "github.com/astaxie/beego"
// "github.com/astaxie/beego/plugins/authz"
// "github.com/casbin/casbin"
// )
//
// func main(){
// // mediate the access for every request
// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
// beego.Run()
// }
//
//
// Advanced Usage:
//
// func main(){
// e := casbin.NewEnforcer("authz_model.conf", "")
// e.AddRoleForUser("alice", "admin")
// e.AddPolicy(...)
//
// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e))
// beego.Run()
// }
package authz
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/context"
"github.com/casbin/casbin"
"net/http"
)
// NewAuthorizer returns the authorizer.
// Use a casbin enforcer as input
func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc {
return func(ctx *context.Context) {
a := &BasicAuthorizer{enforcer: e}
if !a.CheckPermission(ctx.Request) {
a.RequirePermission(ctx.ResponseWriter)
}
}
}
// BasicAuthorizer stores the casbin handler
type BasicAuthorizer struct {
enforcer *casbin.Enforcer
}
// GetUserName gets the user name from the request.
// Currently, only HTTP basic authentication is supported
func (a *BasicAuthorizer) GetUserName(r *http.Request) string {
username, _, _ := r.BasicAuth()
return username
}
// CheckPermission checks the user/method/path combination from the request.
// Returns true (permission granted) or false (permission forbidden)
func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool {
user := a.GetUserName(r)
method := r.Method
path := r.URL.Path
return a.enforcer.Enforce(user, path, method)
}
// RequirePermission returns the 403 Forbidden to the client
func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) {
w.WriteHeader(403)
w.Write([]byte("403 Forbidden\n"))
}

View File

@ -0,0 +1,14 @@
[request_definition]
r = sub, obj, act
[policy_definition]
p = sub, obj, act
[role_definition]
g = _, _
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*")

View File

@ -0,0 +1,7 @@
p, alice, /dataset1/*, GET
p, alice, /dataset1/resource1, POST
p, bob, /dataset2/resource1, *
p, bob, /dataset2/resource2, GET
p, bob, /dataset2/folder1/*, POST
p, dataset1_admin, /dataset1/*, *
g, cathy, dataset1_admin
1 p, alice, /dataset1/*, GET
2 p, alice, /dataset1/resource1, POST
3 p, bob, /dataset2/resource1, *
4 p, bob, /dataset2/resource2, GET
5 p, bob, /dataset2/folder1/*, POST
6 p, dataset1_admin, /dataset1/*, *
7 g, cathy, dataset1_admin

107
plugins/authz/authz_test.go Normal file
View File

@ -0,0 +1,107 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package authz
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/plugins/auth"
"github.com/casbin/casbin"
"net/http"
"net/http/httptest"
"testing"
)
func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) {
r, _ := http.NewRequest(method, path, nil)
r.SetBasicAuth(user, "123")
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
if w.Code != code {
t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code)
}
}
func TestBasic(t *testing.T) {
handler := beego.NewControllerRegister()
handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123"))
handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
handler.Any("*", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200)
testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200)
testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200)
testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403)
}
func TestPathWildcard(t *testing.T) {
handler := beego.NewControllerRegister()
handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123"))
handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
handler.Any("*", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200)
testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200)
testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200)
testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200)
testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403)
testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200)
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200)
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403)
}
func TestRBAC(t *testing.T) {
handler := beego.NewControllerRegister()
handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123"))
e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")
handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e))
handler.Any("*", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
// cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role.
testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200)
testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200)
testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200)
testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
// delete all roles on user cathy, so cathy cannot access any resources now.
e.DeleteRolesForUser("cathy")
testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403)
testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403)
testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
}

View File

@ -23,7 +23,7 @@ import (
// PolicyFunc defines a policy function which is invoked before the controller handler is executed. // PolicyFunc defines a policy function which is invoked before the controller handler is executed.
type PolicyFunc func(*context.Context) type PolicyFunc func(*context.Context)
// FindRouter Find Router info for URL // FindPolicy Find Router info for URL
func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc {
var urlPath = cont.Input.URL() var urlPath = cont.Input.URL()
if !BConfig.RouterCaseSensitive { if !BConfig.RouterCaseSensitive {
@ -71,7 +71,7 @@ func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc
} }
} }
// Register new policy in beego // Policy Register new policy in beego
func Policy(pattern, method string, policy ...PolicyFunc) { func Policy(pattern, method string, policy ...PolicyFunc) {
BeeApp.Handlers.addToPolicy(method, pattern, policy...) BeeApp.Handlers.addToPolicy(method, pattern, policy...)
} }

134
router.go
View File

@ -17,7 +17,6 @@ package beego
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"os"
"path" "path"
"path/filepath" "path/filepath"
"reflect" "reflect"
@ -28,6 +27,7 @@ import (
"time" "time"
beecontext "github.com/astaxie/beego/context" beecontext "github.com/astaxie/beego/context"
"github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
@ -51,15 +51,22 @@ const (
var ( var (
// HTTPMETHOD list the supported http methods. // HTTPMETHOD list the supported http methods.
HTTPMETHOD = map[string]string{ HTTPMETHOD = map[string]string{
"GET": "GET", "GET": "GET",
"POST": "POST", "POST": "POST",
"PUT": "PUT", "PUT": "PUT",
"DELETE": "DELETE", "DELETE": "DELETE",
"PATCH": "PATCH", "PATCH": "PATCH",
"OPTIONS": "OPTIONS", "OPTIONS": "OPTIONS",
"HEAD": "HEAD", "HEAD": "HEAD",
"TRACE": "TRACE", "TRACE": "TRACE",
"CONNECT": "CONNECT", "CONNECT": "CONNECT",
"MKCOL": "MKCOL",
"COPY": "COPY",
"MOVE": "MOVE",
"PROPFIND": "PROPFIND",
"PROPPATCH": "PROPPATCH",
"LOCK": "LOCK",
"UNLOCK": "UNLOCK",
} }
// these beego.Controller's methods shouldn't reflect to AutoRouter // these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
@ -102,13 +109,15 @@ func ExceptMethodAppend(action string) {
exceptMethod = append(exceptMethod, action) exceptMethod = append(exceptMethod, action)
} }
type controllerInfo struct { // ControllerInfo holds information about the controller.
type ControllerInfo struct {
pattern string pattern string
controllerType reflect.Type controllerType reflect.Type
methods map[string]string methods map[string]string
handler http.Handler handler http.Handler
runFunction FilterFunc runFunction FilterFunc
routerType int routerType int
methodParams []*param.MethodParam
} }
// ControllerRegister containers registered router rules, controller handlers and filters. // ControllerRegister containers registered router rules, controller handlers and filters.
@ -144,6 +153,10 @@ func NewControllerRegister() *ControllerRegister {
// Add("/api",&RestController{},"get,post:ApiFunc" // Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
p.addWithMethodParams(pattern, c, nil, mappingMethods...)
}
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) {
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type() t := reflect.Indirect(reflectVal).Type()
methods := make(map[string]string) methods := make(map[string]string)
@ -169,11 +182,12 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM
} }
} }
route := &controllerInfo{} route := &ControllerInfo{}
route.pattern = pattern route.pattern = pattern
route.methods = methods route.methods = methods
route.routerType = routerTypeBeego route.routerType = routerTypeBeego
route.controllerType = t route.controllerType = t
route.methodParams = methodParams
if len(methods) == 0 { if len(methods) == 0 {
for _, m := range HTTPMETHOD { for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
@ -191,7 +205,7 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM
} }
} }
func (p *ControllerRegister) addToRouter(method, pattern string, r *controllerInfo) { func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) {
if !BConfig.RouterCaseSensitive { if !BConfig.RouterCaseSensitive {
pattern = strings.ToLower(pattern) pattern = strings.ToLower(pattern)
} }
@ -212,13 +226,11 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
for _, c := range cList { for _, c := range cList {
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type() t := reflect.Indirect(reflectVal).Type()
gopath := os.Getenv("GOPATH") wgopath := utils.GetGOPATHs()
if gopath == "" { if len(wgopath) == 0 {
panic("you are in dev mode. So please set gopath") panic("you are in dev mode. So please set gopath")
} }
pkgpath := "" pkgpath := ""
wgopath := filepath.SplitList(gopath)
for _, wg := range wgopath { for _, wg := range wgopath {
wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath()))
if utils.FileExists(wg) { if utils.FileExists(wg) {
@ -240,7 +252,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
key := t.PkgPath() + ":" + t.Name() key := t.PkgPath() + ":" + t.Name()
if comm, ok := GlobalControllerRouter[key]; ok { if comm, ok := GlobalControllerRouter[key]; ok {
for _, a := range comm { for _, a := range comm {
p.Add(a.Router, c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
} }
} }
} }
@ -328,7 +340,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
if _, ok := HTTPMETHOD[method]; method != "*" && !ok { if _, ok := HTTPMETHOD[method]; method != "*" && !ok {
panic("not support http method: " + method) panic("not support http method: " + method)
} }
route := &controllerInfo{} route := &ControllerInfo{}
route.pattern = pattern route.pattern = pattern
route.routerType = routerTypeRESTFul route.routerType = routerTypeRESTFul
route.runFunction = f route.runFunction = f
@ -354,7 +366,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
// Handler add user defined Handler // Handler add user defined Handler
func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
route := &controllerInfo{} route := &ControllerInfo{}
route.pattern = pattern route.pattern = pattern
route.routerType = routerTypeHandler route.routerType = routerTypeHandler
route.handler = h route.handler = h
@ -389,7 +401,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
controllerName := strings.TrimSuffix(ct.Name(), "Controller") controllerName := strings.TrimSuffix(ct.Name(), "Controller")
for i := 0; i < rt.NumMethod(); i++ { for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) { if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
route := &controllerInfo{} route := &ControllerInfo{}
route.routerType = routerTypeBeego route.routerType = routerTypeBeego
route.methods = map[string]string{"*": rt.Method(i).Name} route.methods = map[string]string{"*": rt.Method(i).Name}
route.controllerType = ct route.controllerType = ct
@ -495,7 +507,7 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
} }
} }
for _, l := range t.leaves { for _, l := range t.leaves {
if c, ok := l.runObject.(*controllerInfo); ok { if c, ok := l.runObject.(*ControllerInfo); ok {
if c.routerType == routerTypeBeego && if c.routerType == routerTypeBeego &&
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) { strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) {
find := false find := false
@ -619,11 +631,12 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
startTime := time.Now() startTime := time.Now()
var ( var (
runRouter reflect.Type runRouter reflect.Type
findRouter bool findRouter bool
runMethod string runMethod string
routerInfo *controllerInfo methodParams []*param.MethodParam
isRunnable bool routerInfo *ControllerInfo
isRunnable bool
) )
context := p.pool.Get().(*beecontext.Context) context := p.pool.Get().(*beecontext.Context)
context.Reset(rw, r) context.Reset(rw, r)
@ -663,7 +676,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
goto Admin goto Admin
} }
if r.Method != "GET" && r.Method != "HEAD" { if r.Method != http.MethodGet && r.Method != http.MethodHead {
if BConfig.CopyRequestBody && !context.Input.IsUpload() { if BConfig.CopyRequestBody && !context.Input.IsUpload() {
context.Input.CopyBody(BConfig.MaxMemory) context.Input.CopyBody(BConfig.MaxMemory)
} }
@ -691,7 +704,6 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// User can define RunController and RunMethod in filter // User can define RunController and RunMethod in filter
if context.Input.RunController != nil && context.Input.RunMethod != "" { if context.Input.RunController != nil && context.Input.RunMethod != "" {
findRouter = true findRouter = true
isRunnable = true
runMethod = context.Input.RunMethod runMethod = context.Input.RunMethod
runRouter = context.Input.RunController runRouter = context.Input.RunController
} else { } else {
@ -735,12 +747,13 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
routerInfo.handler.ServeHTTP(rw, r) routerInfo.handler.ServeHTTP(rw, r)
} else { } else {
runRouter = routerInfo.controllerType runRouter = routerInfo.controllerType
methodParams = routerInfo.methodParams
method := r.Method method := r.Method
if r.Method == "POST" && context.Input.Query("_method") == "PUT" { if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost {
method = "PUT" method = http.MethodPut
} }
if r.Method == "POST" && context.Input.Query("_method") == "DELETE" { if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete {
method = "DELETE" method = http.MethodDelete
} }
if m, ok := routerInfo.methods[method]; ok { if m, ok := routerInfo.methods[method]; ok {
runMethod = m runMethod = m
@ -770,8 +783,8 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf //if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
if BConfig.WebConfig.EnableXSRF { if BConfig.WebConfig.EnableXSRF {
execController.XSRFToken() execController.XSRFToken()
if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" || if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut ||
(r.Method == "POST" && (context.Input.Query("_method") == "DELETE" || context.Input.Query("_method") == "PUT")) { (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) {
execController.CheckXSRFCookie() execController.CheckXSRFCookie()
} }
} }
@ -781,25 +794,30 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if !context.ResponseWriter.Started { if !context.ResponseWriter.Started {
//exec main logic //exec main logic
switch runMethod { switch runMethod {
case "GET": case http.MethodGet:
execController.Get() execController.Get()
case "POST": case http.MethodPost:
execController.Post() execController.Post()
case "DELETE": case http.MethodDelete:
execController.Delete() execController.Delete()
case "PUT": case http.MethodPut:
execController.Put() execController.Put()
case "HEAD": case http.MethodHead:
execController.Head() execController.Head()
case "PATCH": case http.MethodPatch:
execController.Patch() execController.Patch()
case "OPTIONS": case http.MethodOptions:
execController.Options() execController.Options()
default: default:
if !execController.HandlerFunc(runMethod) { if !execController.HandlerFunc(runMethod) {
var in []reflect.Value
method := vc.MethodByName(runMethod) method := vc.MethodByName(runMethod)
method.Call(in) in := param.ConvertParams(methodParams, method.Type(), context)
out := method.Call(in)
//For backward compatibility we only handle response if we had incoming methodParams
if methodParams != nil {
p.handleParamResponse(context, execController, out)
}
} }
} }
@ -830,7 +848,15 @@ Admin:
//admin module record QPS //admin module record QPS
if BConfig.Listen.EnableAdmin { if BConfig.Listen.EnableAdmin {
timeDur := time.Since(startTime) timeDur := time.Since(startTime)
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) { pattern := ""
if routerInfo != nil {
pattern = routerInfo.pattern
}
statusCode := context.ResponseWriter.Status
if statusCode == 0 {
statusCode = 200
}
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) {
if runRouter != nil { if runRouter != nil {
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur) go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
} else { } else {
@ -879,8 +905,22 @@ Admin:
} }
} }
func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
//looping in reverse order for the case when both error and value are returned and error sets the response status code
for i := len(results) - 1; i >= 0; i-- {
result := results[i]
if result.Kind() != reflect.Interface || !result.IsNil() {
resultValue := result.Interface()
context.RenderMethodResult(resultValue)
}
}
if !context.ResponseWriter.Started && context.Output.Status == 0 {
context.Output.SetStatus(200)
}
}
// FindRouter Find Router info for URL // FindRouter Find Router info for URL
func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *controllerInfo, isFind bool) { func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) {
var urlPath = context.Input.URL() var urlPath = context.Input.URL()
if !BConfig.RouterCaseSensitive { if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(urlPath) urlPath = strings.ToLower(urlPath)
@ -888,7 +928,7 @@ func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo
httpMethod := context.Input.Method() httpMethod := context.Input.Method()
if t, ok := p.routers[httpMethod]; ok { if t, ok := p.routers[httpMethod]; ok {
runObject := t.Match(urlPath, context) runObject := t.Match(urlPath, context)
if r, ok := runObject.(*controllerInfo); ok { if r, ok := runObject.(*ControllerInfo); ok {
return r, true return r, true
} }
} }

Some files were not shown because too many files have changed in this diff Show More