mirror of
https://github.com/astaxie/beego.git
synced 2024-11-22 13:00:54 +00:00
Merge branch 'master' into master
This commit is contained in:
commit
72ec4df679
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,5 @@
|
|||||||
.idea
|
.idea
|
||||||
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
*.swp
|
*.swp
|
||||||
*.swo
|
*.swo
|
||||||
|
4
.gosimpleignore
Normal file
4
.gosimpleignore
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
github.com/astaxie/beego/*/*:S1012
|
||||||
|
github.com/astaxie/beego/*:S1012
|
||||||
|
github.com/astaxie/beego/*/*:S1007
|
||||||
|
github.com/astaxie/beego/*:S1007
|
19
.travis.yml
19
.travis.yml
@ -1,9 +1,9 @@
|
|||||||
language: go
|
language: go
|
||||||
|
|
||||||
go:
|
go:
|
||||||
- 1.6
|
- 1.6.4
|
||||||
- 1.5.3
|
- 1.7.5
|
||||||
- 1.4.3
|
- 1.8.1
|
||||||
services:
|
services:
|
||||||
- redis-server
|
- redis-server
|
||||||
- mysql
|
- mysql
|
||||||
@ -31,6 +31,14 @@ install:
|
|||||||
- go get github.com/siddontang/ledisdb/config
|
- go get github.com/siddontang/ledisdb/config
|
||||||
- go get github.com/siddontang/ledisdb/ledis
|
- go get github.com/siddontang/ledisdb/ledis
|
||||||
- go get github.com/ssdb/gossdb/ssdb
|
- go get github.com/ssdb/gossdb/ssdb
|
||||||
|
- go get github.com/cloudflare/golz4
|
||||||
|
- go get github.com/gogo/protobuf/proto
|
||||||
|
- go get github.com/Knetic/govaluate
|
||||||
|
- go get github.com/casbin/casbin
|
||||||
|
- go get -u honnef.co/go/tools/cmd/gosimple
|
||||||
|
- go get -u github.com/mdempsky/unconvert
|
||||||
|
- go get -u github.com/gordonklaus/ineffassign
|
||||||
|
- go get -u github.com/golang/lint/golint
|
||||||
before_script:
|
before_script:
|
||||||
- psql --version
|
- psql --version
|
||||||
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
|
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
|
||||||
@ -45,5 +53,10 @@ after_script:
|
|||||||
- rm -rf ./res/var/*
|
- rm -rf ./res/var/*
|
||||||
script:
|
script:
|
||||||
- go test -v ./...
|
- go test -v ./...
|
||||||
|
- gosimple -ignore "$(cat .gosimpleignore)" $(go list ./... | grep -v /vendor/)
|
||||||
|
- unconvert $(go list ./... | grep -v /vendor/)
|
||||||
|
- ineffassign .
|
||||||
|
- find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s
|
||||||
|
- golint ./...
|
||||||
addons:
|
addons:
|
||||||
postgresql: "9.4"
|
postgresql: "9.4"
|
||||||
|
32
README.md
32
README.md
@ -1,20 +1,17 @@
|
|||||||
## Beego
|
# Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org)
|
||||||
|
|
||||||
[![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego)
|
|
||||||
[![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego)
|
|
||||||
[![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org)
|
|
||||||
|
|
||||||
beego is used for rapid development of RESTful APIs, web apps and backend services in Go.
|
beego is used for rapid development of RESTful APIs, web apps and backend services in Go.
|
||||||
It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding.
|
It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding.
|
||||||
|
|
||||||
More info [beego.me](http://beego.me)
|
###### More info at [beego.me](http://beego.me).
|
||||||
|
|
||||||
##Quick Start
|
## Quick Start
|
||||||
######Download and install
|
|
||||||
|
#### Download and install
|
||||||
|
|
||||||
go get github.com/astaxie/beego
|
go get github.com/astaxie/beego
|
||||||
|
|
||||||
######Create file `hello.go`
|
#### Create file `hello.go`
|
||||||
```go
|
```go
|
||||||
package main
|
package main
|
||||||
|
|
||||||
@ -24,15 +21,16 @@ func main(){
|
|||||||
beego.Run()
|
beego.Run()
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
######Build and run
|
#### Build and run
|
||||||
```bash
|
|
||||||
go build hello.go
|
go build hello.go
|
||||||
./hello
|
./hello
|
||||||
```
|
|
||||||
######Congratulations!
|
#### Go to [http://localhost:8080](http://localhost:8080)
|
||||||
You just built your first beego app.
|
|
||||||
Open your browser and visit `http://localhost:8080`.
|
Congratulations! You've just built your first **beego** app.
|
||||||
Please see [Documentation](http://beego.me/docs) for more.
|
|
||||||
|
###### Please see [Documentation](http://beego.me/docs) for more.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
@ -56,7 +54,7 @@ Please see [Documentation](http://beego.me/docs) for more.
|
|||||||
* [http://beego.me/community](http://beego.me/community)
|
* [http://beego.me/community](http://beego.me/community)
|
||||||
* Welcome to join us in Slack: [https://beego.slack.com](https://beego.slack.com), you can get invited from [here](https://github.com/beego/beedoc/issues/232)
|
* Welcome to join us in Slack: [https://beego.slack.com](https://beego.slack.com), you can get invited from [here](https://github.com/beego/beedoc/issues/232)
|
||||||
|
|
||||||
## LICENSE
|
## License
|
||||||
|
|
||||||
beego source code is licensed under the Apache Licence, Version 2.0
|
beego source code is licensed under the Apache Licence, Version 2.0
|
||||||
(http://www.apache.org/licenses/LICENSE-2.0.html).
|
(http://www.apache.org/licenses/LICENSE-2.0.html).
|
||||||
|
87
admin.go
87
admin.go
@ -37,7 +37,7 @@ var beeAdminApp *adminApp
|
|||||||
// FilterMonitorFunc is default monitor filter when admin module is enable.
|
// FilterMonitorFunc is default monitor filter when admin module is enable.
|
||||||
// if this func returns, admin module records qbs for this request by condition of this function logic.
|
// if this func returns, admin module records qbs for this request by condition of this function logic.
|
||||||
// usage:
|
// usage:
|
||||||
// func MyFilterMonitor(method, requestPath string, t time.Duration) bool {
|
// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool {
|
||||||
// if method == "POST" {
|
// if method == "POST" {
|
||||||
// return false
|
// return false
|
||||||
// }
|
// }
|
||||||
@ -50,7 +50,7 @@ var beeAdminApp *adminApp
|
|||||||
// return true
|
// return true
|
||||||
// }
|
// }
|
||||||
// beego.FilterMonitorFunc = MyFilterMonitor.
|
// beego.FilterMonitorFunc = MyFilterMonitor.
|
||||||
var FilterMonitorFunc func(string, string, time.Duration) bool
|
var FilterMonitorFunc func(string, string, time.Duration, string, int) bool
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
beeAdminApp = &adminApp{
|
beeAdminApp = &adminApp{
|
||||||
@ -62,7 +62,7 @@ func init() {
|
|||||||
beeAdminApp.Route("/healthcheck", healthcheck)
|
beeAdminApp.Route("/healthcheck", healthcheck)
|
||||||
beeAdminApp.Route("/task", taskStatus)
|
beeAdminApp.Route("/task", taskStatus)
|
||||||
beeAdminApp.Route("/listconf", listConf)
|
beeAdminApp.Route("/listconf", listConf)
|
||||||
FilterMonitorFunc = func(string, string, time.Duration) bool { return true }
|
FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true }
|
||||||
}
|
}
|
||||||
|
|
||||||
// AdminIndex is the default http.Handler for admin module.
|
// AdminIndex is the default http.Handler for admin module.
|
||||||
@ -105,29 +105,12 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
|
|||||||
tmpl.Execute(rw, data)
|
tmpl.Execute(rw, data)
|
||||||
|
|
||||||
case "router":
|
case "router":
|
||||||
var (
|
content := PrintTree()
|
||||||
content = map[string]interface{}{
|
content["Fields"] = []string{
|
||||||
"Fields": []string{
|
"Router Pattern",
|
||||||
"Router Pattern",
|
"Methods",
|
||||||
"Methods",
|
"Controller",
|
||||||
"Controller",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
methods = []string{}
|
|
||||||
methodsData = make(map[string]interface{})
|
|
||||||
)
|
|
||||||
for method, t := range BeeApp.Handlers.routers {
|
|
||||||
|
|
||||||
resultList := new([][]string)
|
|
||||||
|
|
||||||
printTree(resultList, t)
|
|
||||||
|
|
||||||
methods = append(methods, method)
|
|
||||||
methodsData[method] = resultList
|
|
||||||
}
|
}
|
||||||
|
|
||||||
content["Data"] = methodsData
|
|
||||||
content["Methods"] = methods
|
|
||||||
data["Content"] = content
|
data["Content"] = content
|
||||||
data["Title"] = "Routers"
|
data["Title"] = "Routers"
|
||||||
execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
|
execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl)
|
||||||
@ -157,8 +140,8 @@ func listConf(rw http.ResponseWriter, r *http.Request) {
|
|||||||
resultList := new([][]string)
|
resultList := new([][]string)
|
||||||
for _, f := range bf {
|
for _, f := range bf {
|
||||||
var result = []string{
|
var result = []string{
|
||||||
fmt.Sprintf("%s", f.pattern),
|
f.pattern,
|
||||||
fmt.Sprintf("%s", utils.GetFuncName(f.filterFunc)),
|
utils.GetFuncName(f.filterFunc),
|
||||||
}
|
}
|
||||||
*resultList = append(*resultList, result)
|
*resultList = append(*resultList, result)
|
||||||
}
|
}
|
||||||
@ -200,6 +183,28 @@ func list(root string, p interface{}, m map[string]interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrintTree prints all registered routers.
|
||||||
|
func PrintTree() map[string]interface{} {
|
||||||
|
var (
|
||||||
|
content = map[string]interface{}{}
|
||||||
|
methods = []string{}
|
||||||
|
methodsData = make(map[string]interface{})
|
||||||
|
)
|
||||||
|
for method, t := range BeeApp.Handlers.routers {
|
||||||
|
|
||||||
|
resultList := new([][]string)
|
||||||
|
|
||||||
|
printTree(resultList, t)
|
||||||
|
|
||||||
|
methods = append(methods, method)
|
||||||
|
methodsData[method] = resultList
|
||||||
|
}
|
||||||
|
|
||||||
|
content["Data"] = methodsData
|
||||||
|
content["Methods"] = methods
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
func printTree(resultList *[][]string, t *Tree) {
|
func printTree(resultList *[][]string, t *Tree) {
|
||||||
for _, tr := range t.fixrouters {
|
for _, tr := range t.fixrouters {
|
||||||
printTree(resultList, tr)
|
printTree(resultList, tr)
|
||||||
@ -208,12 +213,12 @@ func printTree(resultList *[][]string, t *Tree) {
|
|||||||
printTree(resultList, t.wildcard)
|
printTree(resultList, t.wildcard)
|
||||||
}
|
}
|
||||||
for _, l := range t.leaves {
|
for _, l := range t.leaves {
|
||||||
if v, ok := l.runObject.(*controllerInfo); ok {
|
if v, ok := l.runObject.(*ControllerInfo); ok {
|
||||||
if v.routerType == routerTypeBeego {
|
if v.routerType == routerTypeBeego {
|
||||||
var result = []string{
|
var result = []string{
|
||||||
v.pattern,
|
v.pattern,
|
||||||
fmt.Sprintf("%s", v.methods),
|
fmt.Sprintf("%s", v.methods),
|
||||||
fmt.Sprintf("%s", v.controllerType),
|
v.controllerType.String(),
|
||||||
}
|
}
|
||||||
*resultList = append(*resultList, result)
|
*resultList = append(*resultList, result)
|
||||||
} else if v.routerType == routerTypeRESTFul {
|
} else if v.routerType == routerTypeRESTFul {
|
||||||
@ -276,8 +281,8 @@ func profIndex(rw http.ResponseWriter, r *http.Request) {
|
|||||||
// it's in "/healthcheck" pattern in admin module.
|
// it's in "/healthcheck" pattern in admin module.
|
||||||
func healthcheck(rw http.ResponseWriter, req *http.Request) {
|
func healthcheck(rw http.ResponseWriter, req *http.Request) {
|
||||||
var (
|
var (
|
||||||
|
result []string
|
||||||
data = make(map[interface{}]interface{})
|
data = make(map[interface{}]interface{})
|
||||||
result = []string{}
|
|
||||||
resultList = new([][]string)
|
resultList = new([][]string)
|
||||||
content = map[string]interface{}{
|
content = map[string]interface{}{
|
||||||
"Fields": []string{"Name", "Message", "Status"},
|
"Fields": []string{"Name", "Message", "Status"},
|
||||||
@ -287,21 +292,20 @@ func healthcheck(rw http.ResponseWriter, req *http.Request) {
|
|||||||
for name, h := range toolbox.AdminCheckList {
|
for name, h := range toolbox.AdminCheckList {
|
||||||
if err := h.Check(); err != nil {
|
if err := h.Check(); err != nil {
|
||||||
result = []string{
|
result = []string{
|
||||||
fmt.Sprintf("error"),
|
"error",
|
||||||
fmt.Sprintf("%s", name),
|
name,
|
||||||
fmt.Sprintf("%s", err.Error()),
|
err.Error(),
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
result = []string{
|
result = []string{
|
||||||
fmt.Sprintf("success"),
|
"success",
|
||||||
fmt.Sprintf("%s", name),
|
name,
|
||||||
fmt.Sprintf("OK"),
|
"OK",
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
*resultList = append(*resultList, result)
|
*resultList = append(*resultList, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
content["Data"] = resultList
|
content["Data"] = resultList
|
||||||
data["Content"] = content
|
data["Content"] = content
|
||||||
data["Title"] = "Health Check"
|
data["Title"] = "Health Check"
|
||||||
@ -330,7 +334,6 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
|
|||||||
// List Tasks
|
// List Tasks
|
||||||
content := make(map[string]interface{})
|
content := make(map[string]interface{})
|
||||||
resultList := new([][]string)
|
resultList := new([][]string)
|
||||||
var result = []string{}
|
|
||||||
var fields = []string{
|
var fields = []string{
|
||||||
"Task Name",
|
"Task Name",
|
||||||
"Task Spec",
|
"Task Spec",
|
||||||
@ -339,10 +342,10 @@ func taskStatus(rw http.ResponseWriter, req *http.Request) {
|
|||||||
"",
|
"",
|
||||||
}
|
}
|
||||||
for tname, tk := range toolbox.AdminTaskList {
|
for tname, tk := range toolbox.AdminTaskList {
|
||||||
result = []string{
|
result := []string{
|
||||||
tname,
|
tname,
|
||||||
fmt.Sprintf("%s", tk.GetSpec()),
|
tk.GetSpec(),
|
||||||
fmt.Sprintf("%s", tk.GetStatus()),
|
tk.GetStatus(),
|
||||||
tk.GetPrev().String(),
|
tk.GetPrev().String(),
|
||||||
}
|
}
|
||||||
*resultList = append(*resultList, result)
|
*resultList = append(*resultList, result)
|
||||||
|
6
app.go
6
app.go
@ -348,9 +348,9 @@ func Any(rootpath string, f FilterFunc) *App {
|
|||||||
|
|
||||||
// Handler used to register a Handler router
|
// Handler used to register a Handler router
|
||||||
// usage:
|
// usage:
|
||||||
// beego.Handler("/api", func(ctx *context.Context){
|
// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
|
||||||
// ctx.Output.Body("hello world")
|
// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path))
|
||||||
// })
|
// }))
|
||||||
func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
|
func Handler(rootpath string, h http.Handler, options ...interface{}) *App {
|
||||||
BeeApp.Handlers.Handler(rootpath, h, options...)
|
BeeApp.Handlers.Handler(rootpath, h, options...)
|
||||||
return BeeApp
|
return BeeApp
|
||||||
|
22
beego.go
22
beego.go
@ -23,7 +23,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// VERSION represent beego web framework version.
|
// VERSION represent beego web framework version.
|
||||||
VERSION = "1.7.2"
|
VERSION = "1.9.0"
|
||||||
|
|
||||||
// DEV is for develop
|
// DEV is for develop
|
||||||
DEV = "dev"
|
DEV = "dev"
|
||||||
@ -40,9 +40,9 @@ var (
|
|||||||
|
|
||||||
// AddAPPStartHook is used to register the hookfunc
|
// AddAPPStartHook is used to register the hookfunc
|
||||||
// The hookfuncs will run in beego.Run()
|
// The hookfuncs will run in beego.Run()
|
||||||
// such as sessionInit, middleware start, buildtemplate, admin start
|
// such as initiating session , starting middleware , building template, starting admin control and so on.
|
||||||
func AddAPPStartHook(hf hookfunc) {
|
func AddAPPStartHook(hf ...hookfunc) {
|
||||||
hooks = append(hooks, hf)
|
hooks = append(hooks, hf...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run beego application.
|
// Run beego application.
|
||||||
@ -69,12 +69,14 @@ func Run(params ...string) {
|
|||||||
|
|
||||||
func initBeforeHTTPRun() {
|
func initBeforeHTTPRun() {
|
||||||
//init hooks
|
//init hooks
|
||||||
AddAPPStartHook(registerMime)
|
AddAPPStartHook(
|
||||||
AddAPPStartHook(registerDefaultErrorHandler)
|
registerMime,
|
||||||
AddAPPStartHook(registerSession)
|
registerDefaultErrorHandler,
|
||||||
AddAPPStartHook(registerTemplate)
|
registerSession,
|
||||||
AddAPPStartHook(registerAdmin)
|
registerTemplate,
|
||||||
AddAPPStartHook(registerGzip)
|
registerAdmin,
|
||||||
|
registerGzip,
|
||||||
|
)
|
||||||
|
|
||||||
for _, hk := range hooks {
|
for _, hk := range hooks {
|
||||||
if err := hk(); err != nil {
|
if err := hk(); err != nil {
|
||||||
|
2
cache/conv.go
vendored
2
cache/conv.go
vendored
@ -28,7 +28,7 @@ func GetString(v interface{}) string {
|
|||||||
return string(result)
|
return string(result)
|
||||||
default:
|
default:
|
||||||
if v != nil {
|
if v != nil {
|
||||||
return fmt.Sprintf("%v", result)
|
return fmt.Sprint(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
|
6
cache/conv_test.go
vendored
6
cache/conv_test.go
vendored
@ -118,14 +118,14 @@ func TestGetFloat64(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetBool(t *testing.T) {
|
func TestGetBool(t *testing.T) {
|
||||||
var t1 = true
|
var t1 = true
|
||||||
if true != GetBool(t1) {
|
if !GetBool(t1) {
|
||||||
t.Error("get bool from bool error")
|
t.Error("get bool from bool error")
|
||||||
}
|
}
|
||||||
var t2 = "true"
|
var t2 = "true"
|
||||||
if true != GetBool(t2) {
|
if !GetBool(t2) {
|
||||||
t.Error("get bool from string error")
|
t.Error("get bool from string error")
|
||||||
}
|
}
|
||||||
if false != GetBool(nil) {
|
if GetBool(nil) {
|
||||||
t.Error("get bool from nil error")
|
t.Error("get bool from nil error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
25
cache/file.go
vendored
25
cache/file.go
vendored
@ -22,6 +22,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -222,33 +223,13 @@ func exists(path string) (bool, error) {
|
|||||||
// FileGetContents Get bytes to file.
|
// FileGetContents Get bytes to file.
|
||||||
// if non-exist, create this file.
|
// if non-exist, create this file.
|
||||||
func FileGetContents(filename string) (data []byte, e error) {
|
func FileGetContents(filename string) (data []byte, e error) {
|
||||||
f, e := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm)
|
return ioutil.ReadFile(filename)
|
||||||
if e != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
stat, e := f.Stat()
|
|
||||||
if e != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data = make([]byte, stat.Size())
|
|
||||||
result, e := f.Read(data)
|
|
||||||
if e != nil || int64(result) != stat.Size() {
|
|
||||||
return nil, e
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilePutContents Put bytes to file.
|
// FilePutContents Put bytes to file.
|
||||||
// if non-exist, create this file.
|
// if non-exist, create this file.
|
||||||
func FilePutContents(filename string, content []byte) error {
|
func FilePutContents(filename string, content []byte) error {
|
||||||
fp, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm)
|
return ioutil.WriteFile(filename, content, os.ModePerm)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer fp.Close()
|
|
||||||
_, err = fp.Write(content)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GobEncode Gob encodes file cache item.
|
// GobEncode Gob encodes file cache item.
|
||||||
|
5
cache/memcache/memcache.go
vendored
5
cache/memcache/memcache.go
vendored
@ -146,10 +146,7 @@ func (rc *Cache) IsExist(key string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err := rc.conn.Get(key)
|
_, err := rc.conn.Get(key)
|
||||||
if err != nil {
|
return !(err != nil)
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearAll clear all cached in memcache.
|
// ClearAll clear all cached in memcache.
|
||||||
|
31
cache/memory.go
vendored
31
cache/memory.go
vendored
@ -217,26 +217,31 @@ func (bc *MemoryCache) vaccuum() {
|
|||||||
if bc.items == nil {
|
if bc.items == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for name := range bc.items {
|
if keys := bc.expiredKeys(); len(keys) != 0 {
|
||||||
bc.itemExpired(name)
|
bc.clearItems(keys)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// itemExpired returns true if an item is expired.
|
// expiredKeys returns key list which are expired.
|
||||||
func (bc *MemoryCache) itemExpired(name string) bool {
|
func (bc *MemoryCache) expiredKeys() (keys []string) {
|
||||||
|
bc.RLock()
|
||||||
|
defer bc.RUnlock()
|
||||||
|
for key, itm := range bc.items {
|
||||||
|
if itm.isExpire() {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// clearItems removes all the items which key in keys.
|
||||||
|
func (bc *MemoryCache) clearItems(keys []string) {
|
||||||
bc.Lock()
|
bc.Lock()
|
||||||
defer bc.Unlock()
|
defer bc.Unlock()
|
||||||
|
for _, key := range keys {
|
||||||
itm, ok := bc.items[name]
|
delete(bc.items, key)
|
||||||
if !ok {
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
if itm.isExpire() {
|
|
||||||
delete(bc.items, name)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
2
cache/redis/redis.go
vendored
2
cache/redis/redis.go
vendored
@ -137,7 +137,7 @@ func (rc *Cache) IsExist(key string) bool {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if v == false {
|
if !v {
|
||||||
if _, err = rc.do("HDEL", rc.key, key); err != nil {
|
if _, err = rc.do("HDEL", rc.key, key); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
21
cache/ssdb/ssdb.go
vendored
21
cache/ssdb/ssdb.go
vendored
@ -53,7 +53,7 @@ func (rc *Cache) GetMulti(keys []string) []interface{} {
|
|||||||
resSize := len(res)
|
resSize := len(res)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for i := 1; i < resSize; i += 2 {
|
for i := 1; i < resSize; i += 2 {
|
||||||
values = append(values, string(res[i+1]))
|
values = append(values, res[i+1])
|
||||||
}
|
}
|
||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
@ -71,10 +71,7 @@ func (rc *Cache) DelMulti(keys []string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err := rc.conn.Do("multi_del", keys)
|
_, err := rc.conn.Do("multi_del", keys)
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put put value to memcache. only support string.
|
// Put put value to memcache. only support string.
|
||||||
@ -113,10 +110,7 @@ func (rc *Cache) Delete(key string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err := rc.conn.Del(key)
|
_, err := rc.conn.Del(key)
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Incr increase counter.
|
// Incr increase counter.
|
||||||
@ -152,7 +146,7 @@ func (rc *Cache) IsExist(key string) bool {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if resp[1] == "1" {
|
if len(resp) == 2 && resp[1] == "1" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@ -175,7 +169,7 @@ func (rc *Cache) ClearAll() error {
|
|||||||
}
|
}
|
||||||
keys := []string{}
|
keys := []string{}
|
||||||
for i := 1; i < size; i += 2 {
|
for i := 1; i < size; i += 2 {
|
||||||
keys = append(keys, string(resp[i]))
|
keys = append(keys, resp[i])
|
||||||
}
|
}
|
||||||
_, e := rc.conn.Do("multi_del", keys)
|
_, e := rc.conn.Do("multi_del", keys)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
@ -229,10 +223,7 @@ func (rc *Cache) connectInit() error {
|
|||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
rc.conn, err = ssdb.Connect(host, port)
|
rc.conn, err = ssdb.Connect(host, port)
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
16
config.go
16
config.go
@ -41,6 +41,7 @@ type Config struct {
|
|||||||
EnableGzip bool
|
EnableGzip bool
|
||||||
MaxMemory int64
|
MaxMemory int64
|
||||||
EnableErrorsShow bool
|
EnableErrorsShow bool
|
||||||
|
EnableErrorsRender bool
|
||||||
Listen Listen
|
Listen Listen
|
||||||
WebConfig WebConfig
|
WebConfig WebConfig
|
||||||
Log LogConfig
|
Log LogConfig
|
||||||
@ -144,9 +145,6 @@ func init() {
|
|||||||
if err = parseConfig(appConfigPath); err != nil {
|
if err = parseConfig(appConfigPath); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
if err = os.Chdir(AppPath); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func recoverPanic(ctx *context.Context) {
|
func recoverPanic(ctx *context.Context) {
|
||||||
@ -174,7 +172,7 @@ func recoverPanic(ctx *context.Context) {
|
|||||||
logs.Critical(fmt.Sprintf("%s:%d", file, line))
|
logs.Critical(fmt.Sprintf("%s:%d", file, line))
|
||||||
stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
|
stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
|
||||||
}
|
}
|
||||||
if BConfig.RunMode == DEV {
|
if BConfig.RunMode == DEV && BConfig.EnableErrorsRender {
|
||||||
showErr(err, ctx, stack)
|
showErr(err, ctx, stack)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -192,6 +190,7 @@ func newBConfig() *Config {
|
|||||||
EnableGzip: false,
|
EnableGzip: false,
|
||||||
MaxMemory: 1 << 26, //64MB
|
MaxMemory: 1 << 26, //64MB
|
||||||
EnableErrorsShow: true,
|
EnableErrorsShow: true,
|
||||||
|
EnableErrorsRender: true,
|
||||||
Listen: Listen{
|
Listen: Listen{
|
||||||
Graceful: false,
|
Graceful: false,
|
||||||
ServerTimeOut: 0,
|
ServerTimeOut: 0,
|
||||||
@ -257,6 +256,9 @@ func parseConfig(appConfigPath string) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func assignConfig(ac config.Configer) error {
|
func assignConfig(ac config.Configer) error {
|
||||||
|
for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} {
|
||||||
|
assignSingleConfig(i, ac)
|
||||||
|
}
|
||||||
// set the run mode first
|
// set the run mode first
|
||||||
if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
|
if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
|
||||||
BConfig.RunMode = envRunMode
|
BConfig.RunMode = envRunMode
|
||||||
@ -264,10 +266,6 @@ func assignConfig(ac config.Configer) error {
|
|||||||
BConfig.RunMode = runMode
|
BConfig.RunMode = runMode
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} {
|
|
||||||
assignSingleConfig(i, ac)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sd := ac.String("StaticDir"); sd != "" {
|
if sd := ac.String("StaticDir"); sd != "" {
|
||||||
BConfig.WebConfig.StaticDir = map[string]string{}
|
BConfig.WebConfig.StaticDir = map[string]string{}
|
||||||
sds := strings.Fields(sd)
|
sds := strings.Fields(sd)
|
||||||
@ -347,7 +345,7 @@ func assignSingleConfig(p interface{}, ac config.Configer) {
|
|||||||
case reflect.String:
|
case reflect.String:
|
||||||
pf.SetString(ac.DefaultString(name, pf.String()))
|
pf.SetString(ac.DefaultString(name, pf.String()))
|
||||||
case reflect.Int, reflect.Int64:
|
case reflect.Int, reflect.Int64:
|
||||||
pf.SetInt(int64(ac.DefaultInt64(name, pf.Int())))
|
pf.SetInt(ac.DefaultInt64(name, pf.Int()))
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
pf.SetBool(ac.DefaultBool(name, pf.Bool()))
|
pf.SetBool(ac.DefaultBool(name, pf.Bool()))
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
|
@ -189,16 +189,16 @@ func ParseBool(val interface{}) (value bool, err error) {
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
case int8, int32, int64:
|
case int8, int32, int64:
|
||||||
strV := fmt.Sprintf("%s", v)
|
strV := fmt.Sprintf("%d", v)
|
||||||
if strV == "1" {
|
if strV == "1" {
|
||||||
return true, nil
|
return true, nil
|
||||||
} else if strV == "0" {
|
} else if strV == "0" {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
case float64:
|
case float64:
|
||||||
if v == 1 {
|
if v == 1.0 {
|
||||||
return true, nil
|
return true, nil
|
||||||
} else if v == 0 {
|
} else if v == 0.0 {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
87
config/env/env.go
vendored
Normal file
87
config/env/env.go
vendored
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
// Copyright 2014 beego Author. All Rights Reserved.
|
||||||
|
// Copyright 2017 Faissal Elamraoui. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Package env is used to parse environment.
|
||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
var env *utils.BeeMap
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
env = utils.NewBeeMap()
|
||||||
|
for _, e := range os.Environ() {
|
||||||
|
splits := strings.Split(e, "=")
|
||||||
|
env.Set(splits[0], os.Getenv(splits[0]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a value by key.
|
||||||
|
// If the key does not exist, the default value will be returned.
|
||||||
|
func Get(key string, defVal string) string {
|
||||||
|
if val := env.Get(key); val != nil {
|
||||||
|
return val.(string)
|
||||||
|
}
|
||||||
|
return defVal
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustGet returns a value by key.
|
||||||
|
// If the key does not exist, it will return an error.
|
||||||
|
func MustGet(key string) (string, error) {
|
||||||
|
if val := env.Get(key); val != nil {
|
||||||
|
return val.(string), nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no env variable with %s", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets a value in the ENV copy.
|
||||||
|
// This does not affect the child process environment.
|
||||||
|
func Set(key string, value string) {
|
||||||
|
env.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustSet sets a value in the ENV copy and the child process environment.
|
||||||
|
// It returns an error in case the set operation failed.
|
||||||
|
func MustSet(key string, value string) error {
|
||||||
|
err := os.Setenv(key, value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
env.Set(key, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAll returns all keys/values in the current child process environment.
|
||||||
|
func GetAll() map[string]string {
|
||||||
|
items := env.Items()
|
||||||
|
envs := make(map[string]string, env.Count())
|
||||||
|
|
||||||
|
for key, val := range items {
|
||||||
|
switch key := key.(type) {
|
||||||
|
case string:
|
||||||
|
switch val := val.(type) {
|
||||||
|
case string:
|
||||||
|
envs[key] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return envs
|
||||||
|
}
|
75
config/env/env_test.go
vendored
Normal file
75
config/env/env_test.go
vendored
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
// Copyright 2014 beego Author. All Rights Reserved.
|
||||||
|
// Copyright 2017 Faissal Elamraoui. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEnvGet(t *testing.T) {
|
||||||
|
gopath := Get("GOPATH", "")
|
||||||
|
if gopath != os.Getenv("GOPATH") {
|
||||||
|
t.Error("expected GOPATH not empty.")
|
||||||
|
}
|
||||||
|
|
||||||
|
noExistVar := Get("NOEXISTVAR", "foo")
|
||||||
|
if noExistVar != "foo" {
|
||||||
|
t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvMustGet(t *testing.T) {
|
||||||
|
gopath, err := MustGet("GOPATH")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gopath != os.Getenv("GOPATH") {
|
||||||
|
t.Errorf("expected GOPATH to be the same, got %s.", gopath)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = MustGet("NOEXISTVAR")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error to be non-nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvSet(t *testing.T) {
|
||||||
|
Set("MYVAR", "foo")
|
||||||
|
myVar := Get("MYVAR", "bar")
|
||||||
|
if myVar != "foo" {
|
||||||
|
t.Errorf("expected MYVAR to equal foo, got %s.", myVar)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvMustSet(t *testing.T) {
|
||||||
|
err := MustSet("FOO", "bar")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fooVar := os.Getenv("FOO")
|
||||||
|
if fooVar != "bar" {
|
||||||
|
t.Errorf("expected FOO variable to equal bar, got %s.", fooVar)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvGetAll(t *testing.T) {
|
||||||
|
envMap := GetAll()
|
||||||
|
if len(envMap) == 0 {
|
||||||
|
t.Error("expected environment not empty.")
|
||||||
|
}
|
||||||
|
}
|
@ -18,16 +18,14 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"os/user"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -52,24 +50,26 @@ func (ini *IniConfig) Parse(name string) (Configer, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
|
func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
|
||||||
file, err := os.Open(name)
|
data, err := ioutil.ReadFile(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return ini.parseData(filepath.Dir(name), data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) {
|
||||||
cfg := &IniConfigContainer{
|
cfg := &IniConfigContainer{
|
||||||
file.Name(),
|
data: make(map[string]map[string]string),
|
||||||
make(map[string]map[string]string),
|
sectionComment: make(map[string]string),
|
||||||
make(map[string]string),
|
keyComment: make(map[string]string),
|
||||||
make(map[string]string),
|
RWMutex: sync.RWMutex{},
|
||||||
sync.RWMutex{},
|
|
||||||
}
|
}
|
||||||
cfg.Lock()
|
cfg.Lock()
|
||||||
defer cfg.Unlock()
|
defer cfg.Unlock()
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
var comment bytes.Buffer
|
var comment bytes.Buffer
|
||||||
buf := bufio.NewReader(file)
|
buf := bufio.NewReader(bytes.NewBuffer(data))
|
||||||
// check the BOM
|
// check the BOM
|
||||||
head, err := buf.Peek(3)
|
head, err := buf.Peek(3)
|
||||||
if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 {
|
if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 {
|
||||||
@ -130,16 +130,20 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
|
|||||||
|
|
||||||
// handle include "other.conf"
|
// handle include "other.conf"
|
||||||
if len(keyValue) == 1 && strings.HasPrefix(key, "include") {
|
if len(keyValue) == 1 && strings.HasPrefix(key, "include") {
|
||||||
|
|
||||||
includefiles := strings.Fields(key)
|
includefiles := strings.Fields(key)
|
||||||
if includefiles[0] == "include" && len(includefiles) == 2 {
|
if includefiles[0] == "include" && len(includefiles) == 2 {
|
||||||
|
|
||||||
otherfile := strings.Trim(includefiles[1], "\"")
|
otherfile := strings.Trim(includefiles[1], "\"")
|
||||||
if !filepath.IsAbs(otherfile) {
|
if !filepath.IsAbs(otherfile) {
|
||||||
otherfile = filepath.Join(filepath.Dir(name), otherfile)
|
otherfile = filepath.Join(dir, otherfile)
|
||||||
}
|
}
|
||||||
|
|
||||||
i, err := ini.parseFile(otherfile)
|
i, err := ini.parseFile(otherfile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for sec, dt := range i.data {
|
for sec, dt := range i.data {
|
||||||
if _, ok := cfg.data[sec]; !ok {
|
if _, ok := cfg.data[sec]; !ok {
|
||||||
cfg.data[sec] = make(map[string]string)
|
cfg.data[sec] = make(map[string]string)
|
||||||
@ -148,12 +152,15 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
|
|||||||
cfg.data[sec][k] = v
|
cfg.data[sec][k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for sec, comm := range i.sectionComment {
|
for sec, comm := range i.sectionComment {
|
||||||
cfg.sectionComment[sec] = comm
|
cfg.sectionComment[sec] = comm
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, comm := range i.keyComment {
|
for k, comm := range i.keyComment {
|
||||||
cfg.keyComment[k] = comm
|
cfg.keyComment[k] = comm
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -177,20 +184,25 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ParseData parse ini the data
|
// ParseData parse ini the data
|
||||||
|
// When include other.conf,other.conf is either absolute directory
|
||||||
|
// or under beego in default temporary directory(/tmp/beego[-username]).
|
||||||
func (ini *IniConfig) ParseData(data []byte) (Configer, error) {
|
func (ini *IniConfig) ParseData(data []byte) (Configer, error) {
|
||||||
// Save memory data to temporary file
|
dir := "beego"
|
||||||
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
|
currentUser, err := user.Current()
|
||||||
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
|
if err == nil {
|
||||||
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
|
dir = "beego-" + currentUser.Username
|
||||||
|
}
|
||||||
|
dir = filepath.Join(os.TempDir(), dir)
|
||||||
|
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return ini.Parse(tmpName)
|
|
||||||
|
return ini.parseData(dir, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IniConfigContainer A Config represents the ini configuration.
|
// IniConfigContainer A Config represents the ini configuration.
|
||||||
// When set and get value, support key as section:name type.
|
// When set and get value, support key as section:name type.
|
||||||
type IniConfigContainer struct {
|
type IniConfigContainer struct {
|
||||||
filename string
|
|
||||||
data map[string]map[string]string // section=> key:val
|
data map[string]map[string]string // section=> key:val
|
||||||
sectionComment map[string]string // section : comment
|
sectionComment map[string]string // section : comment
|
||||||
keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment.
|
keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment.
|
||||||
@ -297,7 +309,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro
|
|||||||
if v, ok := c.data[section]; ok {
|
if v, ok := c.data[section]; ok {
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("not exist setction")
|
return nil, errors.New("not exist section")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveConfigFile save the config into file.
|
// SaveConfigFile save the config into file.
|
||||||
@ -313,7 +325,10 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
|
|||||||
|
|
||||||
// Get section or key comments. Fixed #1607
|
// Get section or key comments. Fixed #1607
|
||||||
getCommentStr := func(section, key string) string {
|
getCommentStr := func(section, key string) string {
|
||||||
comment, ok := "", false
|
var (
|
||||||
|
comment string
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
if len(key) == 0 {
|
if len(key) == 0 {
|
||||||
comment, ok = c.sectionComment[section]
|
comment, ok = c.sectionComment[section]
|
||||||
} else {
|
} else {
|
||||||
@ -393,11 +408,8 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_, err = buf.WriteTo(f)
|
||||||
if _, err = buf.WriteTo(f); err != nil {
|
return err
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set writes a new value for key.
|
// Set writes a new value for key.
|
||||||
@ -412,7 +424,7 @@ func (c *IniConfigContainer) Set(key, value string) error {
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
section, k string
|
section, k string
|
||||||
sectionKey = strings.Split(key, "::")
|
sectionKey = strings.Split(strings.ToLower(key), "::")
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(sectionKey) >= 2 {
|
if len(sectionKey) >= 2 {
|
||||||
|
@ -181,7 +181,7 @@ name=mysql
|
|||||||
cfgData := string(data)
|
cfgData := string(data)
|
||||||
datas := strings.Split(saveResult, "\n")
|
datas := strings.Split(saveResult, "\n")
|
||||||
for _, line := range datas {
|
for _, line := range datas {
|
||||||
if strings.Contains(cfgData, line+"\n") == false {
|
if !strings.Contains(cfgData, line+"\n") {
|
||||||
t.Fatalf("different after save ini config file. need contains %q", line)
|
t.Fatalf("different after save ini config file. need contains %q", line)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -35,11 +35,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/astaxie/beego/config"
|
"github.com/astaxie/beego/config"
|
||||||
"github.com/beego/x2j"
|
"github.com/beego/x2j"
|
||||||
@ -52,36 +50,26 @@ type Config struct{}
|
|||||||
|
|
||||||
// Parse returns a ConfigContainer with parsed xml config map.
|
// Parse returns a ConfigContainer with parsed xml config map.
|
||||||
func (xc *Config) Parse(filename string) (config.Configer, error) {
|
func (xc *Config) Parse(filename string) (config.Configer, error) {
|
||||||
file, err := os.Open(filename)
|
context, err := ioutil.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
|
return xc.ParseData(context)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseData xml data
|
||||||
|
func (xc *Config) ParseData(data []byte) (config.Configer, error) {
|
||||||
x := &ConfigContainer{data: make(map[string]interface{})}
|
x := &ConfigContainer{data: make(map[string]interface{})}
|
||||||
content, err := ioutil.ReadAll(file)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := x2j.DocToMap(string(content))
|
d, err := x2j.DocToMap(string(data))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{}))
|
x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{}))
|
||||||
return x, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseData xml data
|
return x, nil
|
||||||
func (xc *Config) ParseData(data []byte) (config.Configer, error) {
|
|
||||||
// Save memory data to temporary file
|
|
||||||
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
|
|
||||||
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
|
|
||||||
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return xc.Parse(tmpName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigContainer A Config represents the xml configuration.
|
// ConfigContainer A Config represents the xml configuration.
|
||||||
|
@ -37,10 +37,8 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/astaxie/beego/config"
|
"github.com/astaxie/beego/config"
|
||||||
"github.com/beego/goyaml2"
|
"github.com/beego/goyaml2"
|
||||||
@ -63,26 +61,30 @@ func (yaml *Config) Parse(filename string) (y config.Configer, err error) {
|
|||||||
|
|
||||||
// ParseData parse yaml data
|
// ParseData parse yaml data
|
||||||
func (yaml *Config) ParseData(data []byte) (config.Configer, error) {
|
func (yaml *Config) ParseData(data []byte) (config.Configer, error) {
|
||||||
// Save memory data to temporary file
|
cnf, err := parseYML(data)
|
||||||
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
|
if err != nil {
|
||||||
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
|
|
||||||
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return yaml.Parse(tmpName)
|
|
||||||
|
return &ConfigContainer{
|
||||||
|
data: cnf,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadYmlReader Read yaml file to map.
|
// ReadYmlReader Read yaml file to map.
|
||||||
// if json like, use json package, unless goyaml2 package.
|
// if json like, use json package, unless goyaml2 package.
|
||||||
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
|
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
|
||||||
f, err := os.Open(path)
|
buf, err := ioutil.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
buf, err := ioutil.ReadAll(f)
|
return parseYML(buf)
|
||||||
if err != nil || len(buf) < 3 {
|
}
|
||||||
|
|
||||||
|
// parseYML parse yaml formatted []byte to map.
|
||||||
|
func parseYML(buf []byte) (cnf map[string]interface{}, err error) {
|
||||||
|
if len(buf) < 3 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -250,7 +252,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error)
|
|||||||
if v, ok := c.data[section]; ok {
|
if v, ok := c.data[section]; ok {
|
||||||
return v.(map[string]string), nil
|
return v.(map[string]string), nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("not exist setction")
|
return nil, errors.New("not exist section")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveConfigFile save the config into file
|
// SaveConfigFile save the config into file
|
||||||
|
@ -39,6 +39,7 @@ var (
|
|||||||
getMethodOnly bool
|
getMethodOnly bool
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// InitGzip init the gzipcompress
|
||||||
func InitGzip(minLength, compressLevel int, methods []string) {
|
func InitGzip(minLength, compressLevel int, methods []string) {
|
||||||
if minLength >= 0 {
|
if minLength >= 0 {
|
||||||
gzipMinLength = minLength
|
gzipMinLength = minLength
|
||||||
|
@ -171,6 +171,22 @@ func (ctx *Context) CheckXSRFCookie() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RenderMethodResult renders the return value of a controller method to the output
|
||||||
|
func (ctx *Context) RenderMethodResult(result interface{}) {
|
||||||
|
if result != nil {
|
||||||
|
renderer, ok := result.(Renderer)
|
||||||
|
if !ok {
|
||||||
|
err, ok := result.(error)
|
||||||
|
if ok {
|
||||||
|
renderer = errorRenderer(err)
|
||||||
|
} else {
|
||||||
|
renderer = jsonRenderer(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
renderer.Render(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//Response is a wrapper for the http.ResponseWriter
|
//Response is a wrapper for the http.ResponseWriter
|
||||||
//started set to true if response was written to then don't execute other handler
|
//started set to true if response was written to then don't execute other handler
|
||||||
type Response struct {
|
type Response struct {
|
||||||
|
@ -16,9 +16,11 @@ package context
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
@ -349,11 +351,22 @@ func (input *BeegoInput) CopyBody(MaxMemory int64) []byte {
|
|||||||
if input.Context.Request.Body == nil {
|
if input.Context.Request.Body == nil {
|
||||||
return []byte{}
|
return []byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var requestbody []byte
|
||||||
safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory}
|
safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory}
|
||||||
requestbody, _ := ioutil.ReadAll(safe)
|
if input.Header("Content-Encoding") == "gzip" {
|
||||||
|
reader, err := gzip.NewReader(safe)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
requestbody, _ = ioutil.ReadAll(reader)
|
||||||
|
} else {
|
||||||
|
requestbody, _ = ioutil.ReadAll(safe)
|
||||||
|
}
|
||||||
|
|
||||||
input.Context.Request.Body.Close()
|
input.Context.Request.Body.Close()
|
||||||
bf := bytes.NewBuffer(requestbody)
|
bf := bytes.NewBuffer(requestbody)
|
||||||
input.Context.Request.Body = ioutil.NopCloser(bf)
|
input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, ioutil.NopCloser(bf), MaxMemory)
|
||||||
input.RequestBody = requestbody
|
input.RequestBody = requestbody
|
||||||
return requestbody
|
return requestbody
|
||||||
}
|
}
|
||||||
@ -413,7 +426,13 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error {
|
|||||||
if !value.CanSet() {
|
if !value.CanSet() {
|
||||||
return errors.New("beego: non-settable variable passed to Bind: " + key)
|
return errors.New("beego: non-settable variable passed to Bind: " + key)
|
||||||
}
|
}
|
||||||
rv := input.bind(key, value.Type())
|
typ := value.Type()
|
||||||
|
// Get real type if dest define with interface{}.
|
||||||
|
// e.g var dest interface{} dest=1.0
|
||||||
|
if value.Kind() == reflect.Interface {
|
||||||
|
typ = value.Elem().Type()
|
||||||
|
}
|
||||||
|
rv := input.bind(key, typ)
|
||||||
if !rv.IsValid() {
|
if !rv.IsValid() {
|
||||||
return errors.New("beego: reflect value is empty")
|
return errors.New("beego: reflect value is empty")
|
||||||
}
|
}
|
||||||
@ -422,6 +441,9 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
|
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
|
||||||
|
if input.Context.Request.Form == nil {
|
||||||
|
input.Context.Request.ParseForm()
|
||||||
|
}
|
||||||
rv := reflect.Zero(typ)
|
rv := reflect.Zero(typ)
|
||||||
switch typ.Kind() {
|
switch typ.Kind() {
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
@ -15,81 +15,97 @@
|
|||||||
package context
|
package context
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParse(t *testing.T) {
|
func TestBind(t *testing.T) {
|
||||||
r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
|
type testItem struct {
|
||||||
beegoInput := NewInput()
|
field string
|
||||||
beegoInput.Context = NewContext()
|
empty interface{}
|
||||||
beegoInput.Context.Reset(httptest.NewRecorder(), r)
|
want interface{}
|
||||||
beegoInput.ParseFormOrMulitForm(1 << 20)
|
}
|
||||||
|
type Human struct {
|
||||||
|
ID int
|
||||||
|
Nick string
|
||||||
|
Pwd string
|
||||||
|
Ms bool
|
||||||
|
}
|
||||||
|
|
||||||
var id int
|
cases := []struct {
|
||||||
err := beegoInput.Bind(&id, "id")
|
request string
|
||||||
if id != 123 || err != nil {
|
valueGp []testItem
|
||||||
t.Fatal("id should has int value")
|
}{
|
||||||
}
|
{"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}},
|
||||||
fmt.Println(id)
|
|
||||||
|
|
||||||
var isok bool
|
{"/?p=", []testItem{{"p", "", ""}}},
|
||||||
err = beegoInput.Bind(&isok, "isok")
|
{"/?p=str", []testItem{{"p", "", "str"}}},
|
||||||
if !isok || err != nil {
|
|
||||||
t.Fatal("isok should be true")
|
|
||||||
}
|
|
||||||
fmt.Println(isok)
|
|
||||||
|
|
||||||
var float float64
|
{"/?p=123", []testItem{{"p", 0, 123}}},
|
||||||
err = beegoInput.Bind(&float, "ft")
|
{"/?p=123", []testItem{{"p", uint(0), uint(123)}}},
|
||||||
if float != 1.2 || err != nil {
|
|
||||||
t.Fatal("float should be equal to 1.2")
|
|
||||||
}
|
|
||||||
fmt.Println(float)
|
|
||||||
|
|
||||||
ol := make([]int, 0, 2)
|
{"/?p=1.0", []testItem{{"p", 0.0, 1.0}}},
|
||||||
err = beegoInput.Bind(&ol, "ol")
|
{"/?p=1", []testItem{{"p", false, true}}},
|
||||||
if len(ol) != 2 || err != nil || ol[0] != 1 || ol[1] != 2 {
|
|
||||||
t.Fatal("ol should has two elements")
|
|
||||||
}
|
|
||||||
fmt.Println(ol)
|
|
||||||
|
|
||||||
ul := make([]string, 0, 2)
|
{"/?p=true", []testItem{{"p", false, true}}},
|
||||||
err = beegoInput.Bind(&ul, "ul")
|
{"/?p=ON", []testItem{{"p", false, true}}},
|
||||||
if len(ul) != 2 || err != nil || ul[0] != "str" || ul[1] != "array" {
|
{"/?p=on", []testItem{{"p", false, true}}},
|
||||||
t.Fatal("ul should has two elements")
|
{"/?p=1", []testItem{{"p", false, true}}},
|
||||||
}
|
{"/?p=2", []testItem{{"p", false, false}}},
|
||||||
fmt.Println(ul)
|
{"/?p=false", []testItem{{"p", false, false}}},
|
||||||
|
|
||||||
type User struct {
|
{"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}},
|
||||||
Name string
|
{"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}},
|
||||||
}
|
|
||||||
user := User{}
|
|
||||||
err = beegoInput.Bind(&user, "user")
|
|
||||||
if err != nil || user.Name != "astaxie" {
|
|
||||||
t.Fatal("user should has name")
|
|
||||||
}
|
|
||||||
fmt.Println(user)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParse2(t *testing.T) {
|
{"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}},
|
||||||
r, _ := http.NewRequest("GET", "/?user[0][Username]=Raph&user[1].Username=Leo&user[0].Password=123456&user[1][Password]=654321", nil)
|
{"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}},
|
||||||
beegoInput := NewInput()
|
{"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}},
|
||||||
beegoInput.Context = NewContext()
|
{"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}},
|
||||||
beegoInput.Context.Reset(httptest.NewRecorder(), r)
|
|
||||||
beegoInput.ParseFormOrMulitForm(1 << 20)
|
{"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}},
|
||||||
type User struct {
|
{"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}},
|
||||||
Username string
|
|
||||||
Password string
|
{"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}},
|
||||||
|
|
||||||
|
{"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}},
|
||||||
|
{"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}},
|
||||||
|
{"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02",
|
||||||
|
[]testItem{{"human", []Human{}, []Human{
|
||||||
|
{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"},
|
||||||
|
{ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"},
|
||||||
|
}}}},
|
||||||
|
|
||||||
|
{
|
||||||
|
"/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie",
|
||||||
|
[]testItem{
|
||||||
|
{"id", 0, 123},
|
||||||
|
{"isok", false, true},
|
||||||
|
{"ft", 0.0, 1.2},
|
||||||
|
{"ol", []int{}, []int{1, 2}},
|
||||||
|
{"ul", []string{}, []string{"str", "array"}},
|
||||||
|
{"human", Human{}, Human{Nick: "astaxie"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
var users []User
|
for _, c := range cases {
|
||||||
err := beegoInput.Bind(&users, "user")
|
r, _ := http.NewRequest("GET", c.request, nil)
|
||||||
fmt.Println(users)
|
beegoInput := NewInput()
|
||||||
if err != nil || users[0].Username != "Raph" || users[0].Password != "123456" || users[1].Username != "Leo" || users[1].Password != "654321" {
|
beegoInput.Context = NewContext()
|
||||||
t.Fatal("users info wrong")
|
beegoInput.Context.Reset(httptest.NewRecorder(), r)
|
||||||
|
|
||||||
|
for _, item := range c.valueGp {
|
||||||
|
got := item.empty
|
||||||
|
err := beegoInput.Bind(&got, item.field)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, item.want) {
|
||||||
|
t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,6 +67,7 @@ func (output *BeegoOutput) Body(content []byte) error {
|
|||||||
}
|
}
|
||||||
if b, n, _ := WriteBody(encoding, buf, content); b {
|
if b, n, _ := WriteBody(encoding, buf, content); b {
|
||||||
output.Header("Content-Encoding", n)
|
output.Header("Content-Encoding", n)
|
||||||
|
output.Header("Content-Length", strconv.Itoa(buf.Len()))
|
||||||
} else {
|
} else {
|
||||||
output.Header("Content-Length", strconv.Itoa(len(content)))
|
output.Header("Content-Length", strconv.Itoa(len(content)))
|
||||||
}
|
}
|
||||||
@ -167,6 +168,19 @@ func sanitizeValue(v string) string {
|
|||||||
return cookieValueSanitizer.Replace(v)
|
return cookieValueSanitizer.Replace(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func jsonRenderer(value interface{}) Renderer {
|
||||||
|
return rendererFunc(func(ctx *Context) {
|
||||||
|
ctx.Output.JSON(value, false, false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func errorRenderer(err error) Renderer {
|
||||||
|
return rendererFunc(func(ctx *Context) {
|
||||||
|
ctx.Output.SetStatus(500)
|
||||||
|
ctx.Output.Body([]byte(err.Error()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// JSON writes json to response body.
|
// JSON writes json to response body.
|
||||||
// if coding is true, it converts utf-8 to \u0000 type.
|
// if coding is true, it converts utf-8 to \u0000 type.
|
||||||
func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error {
|
func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error {
|
||||||
@ -329,17 +343,17 @@ func (output *BeegoOutput) IsServerError() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func stringsToJSON(str string) string {
|
func stringsToJSON(str string) string {
|
||||||
rs := []rune(str)
|
var jsons bytes.Buffer
|
||||||
jsons := ""
|
for _, r := range str {
|
||||||
for _, r := range rs {
|
|
||||||
rint := int(r)
|
rint := int(r)
|
||||||
if rint < 128 {
|
if rint < 128 {
|
||||||
jsons += string(r)
|
jsons.WriteRune(r)
|
||||||
} else {
|
} else {
|
||||||
jsons += "\\u" + strconv.FormatInt(int64(rint), 16) // json
|
jsons.WriteString("\\u")
|
||||||
|
jsons.WriteString(strconv.FormatInt(int64(rint), 16))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return jsons
|
return jsons.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session sets session item value with given key.
|
// Session sets session item value with given key.
|
||||||
|
78
context/param/conv.go
Normal file
78
context/param/conv.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package param
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
beecontext "github.com/astaxie/beego/context"
|
||||||
|
"github.com/astaxie/beego/logs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertParams converts http method params to values that will be passed to the method controller as arguments
|
||||||
|
func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) {
|
||||||
|
result = make([]reflect.Value, 0, len(methodParams))
|
||||||
|
for i := 0; i < len(methodParams); i++ {
|
||||||
|
reflectValue := convertParam(methodParams[i], methodType.In(i), ctx)
|
||||||
|
result = append(result, reflectValue)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) {
|
||||||
|
paramValue := getParamValue(param, ctx)
|
||||||
|
if paramValue == "" {
|
||||||
|
if param.required {
|
||||||
|
ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name))
|
||||||
|
} else {
|
||||||
|
paramValue = param.defaultValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectValue, err := parseValue(param, paramValue, paramType)
|
||||||
|
if err != nil {
|
||||||
|
logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err))
|
||||||
|
ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType))
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflectValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func getParamValue(param *MethodParam, ctx *beecontext.Context) string {
|
||||||
|
switch param.in {
|
||||||
|
case body:
|
||||||
|
return string(ctx.Input.RequestBody)
|
||||||
|
case header:
|
||||||
|
return ctx.Input.Header(param.name)
|
||||||
|
case path:
|
||||||
|
return ctx.Input.Query(":" + param.name)
|
||||||
|
default:
|
||||||
|
return ctx.Input.Query(param.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) {
|
||||||
|
if paramValue == "" {
|
||||||
|
return reflect.Zero(paramType), nil
|
||||||
|
}
|
||||||
|
parser := getParser(param, paramType)
|
||||||
|
value, err := parser.parse(paramValue, paramType)
|
||||||
|
if err != nil {
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return safeConvert(reflect.ValueOf(value), paramType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
var ok bool
|
||||||
|
err, ok = r.(error)
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf("%v", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
result = value.Convert(t)
|
||||||
|
return
|
||||||
|
}
|
69
context/param/methodparams.go
Normal file
69
context/param/methodparams.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package param
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
//MethodParam keeps param information to be auto passed to controller methods
|
||||||
|
type MethodParam struct {
|
||||||
|
name string
|
||||||
|
in paramType
|
||||||
|
required bool
|
||||||
|
defaultValue string
|
||||||
|
}
|
||||||
|
|
||||||
|
type paramType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
param paramType = iota
|
||||||
|
path
|
||||||
|
body
|
||||||
|
header
|
||||||
|
)
|
||||||
|
|
||||||
|
//New creates a new MethodParam with name and specific options
|
||||||
|
func New(name string, opts ...MethodParamOption) *MethodParam {
|
||||||
|
return newParam(name, nil, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) {
|
||||||
|
param = &MethodParam{name: name}
|
||||||
|
for _, option := range opts {
|
||||||
|
option(param)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Make creates an array of MethodParmas or an empty array
|
||||||
|
func Make(list ...*MethodParam) []*MethodParam {
|
||||||
|
if len(list) > 0 {
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *MethodParam) String() string {
|
||||||
|
options := []string{}
|
||||||
|
result := "param.New(\"" + mp.name + "\""
|
||||||
|
if mp.required {
|
||||||
|
options = append(options, "param.IsRequired")
|
||||||
|
}
|
||||||
|
switch mp.in {
|
||||||
|
case path:
|
||||||
|
options = append(options, "param.InPath")
|
||||||
|
case body:
|
||||||
|
options = append(options, "param.InBody")
|
||||||
|
case header:
|
||||||
|
options = append(options, "param.InHeader")
|
||||||
|
}
|
||||||
|
if mp.defaultValue != "" {
|
||||||
|
options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue))
|
||||||
|
}
|
||||||
|
if len(options) > 0 {
|
||||||
|
result += ", "
|
||||||
|
}
|
||||||
|
result += strings.Join(options, ", ")
|
||||||
|
result += ")"
|
||||||
|
return result
|
||||||
|
}
|
37
context/param/options.go
Normal file
37
context/param/options.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package param
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MethodParamOption defines a func which apply options on a MethodParam
|
||||||
|
type MethodParamOption func(*MethodParam)
|
||||||
|
|
||||||
|
// IsRequired indicates that this param is required and can not be omitted from the http request
|
||||||
|
var IsRequired MethodParamOption = func(p *MethodParam) {
|
||||||
|
p.required = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// InHeader indicates that this param is passed via an http header
|
||||||
|
var InHeader MethodParamOption = func(p *MethodParam) {
|
||||||
|
p.in = header
|
||||||
|
}
|
||||||
|
|
||||||
|
// InPath indicates that this param is part of the URL path
|
||||||
|
var InPath MethodParamOption = func(p *MethodParam) {
|
||||||
|
p.in = path
|
||||||
|
}
|
||||||
|
|
||||||
|
// InBody indicates that this param is passed as an http request body
|
||||||
|
var InBody MethodParamOption = func(p *MethodParam) {
|
||||||
|
p.in = body
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default provides a default value for the http param
|
||||||
|
func Default(defaultValue interface{}) MethodParamOption {
|
||||||
|
return func(p *MethodParam) {
|
||||||
|
if defaultValue != nil {
|
||||||
|
p.defaultValue = fmt.Sprint(defaultValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
149
context/param/parsers.go
Normal file
149
context/param/parsers.go
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
package param
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type paramParser interface {
|
||||||
|
parse(value string, toType reflect.Type) (interface{}, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getParser(param *MethodParam, t reflect.Type) paramParser {
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return intParser{}
|
||||||
|
case reflect.Slice:
|
||||||
|
if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string
|
||||||
|
return stringParser{}
|
||||||
|
}
|
||||||
|
if param.in == body {
|
||||||
|
return jsonParser{}
|
||||||
|
}
|
||||||
|
elemParser := getParser(param, t.Elem())
|
||||||
|
if elemParser == (jsonParser{}) {
|
||||||
|
return elemParser
|
||||||
|
}
|
||||||
|
return sliceParser(elemParser)
|
||||||
|
case reflect.Bool:
|
||||||
|
return boolParser{}
|
||||||
|
case reflect.String:
|
||||||
|
return stringParser{}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return floatParser{}
|
||||||
|
case reflect.Ptr:
|
||||||
|
elemParser := getParser(param, t.Elem())
|
||||||
|
if elemParser == (jsonParser{}) {
|
||||||
|
return elemParser
|
||||||
|
}
|
||||||
|
return ptrParser(elemParser)
|
||||||
|
default:
|
||||||
|
if t.PkgPath() == "time" && t.Name() == "Time" {
|
||||||
|
return timeParser{}
|
||||||
|
}
|
||||||
|
return jsonParser{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type parserFunc func(value string, toType reflect.Type) (interface{}, error)
|
||||||
|
|
||||||
|
func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) {
|
||||||
|
return f(value, toType)
|
||||||
|
}
|
||||||
|
|
||||||
|
type boolParser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) {
|
||||||
|
return strconv.ParseBool(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
type stringParser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) {
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type intParser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) {
|
||||||
|
return strconv.Atoi(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
type floatParser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) {
|
||||||
|
if toType.Kind() == reflect.Float32 {
|
||||||
|
res, err := strconv.ParseFloat(value, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return float32(res), nil
|
||||||
|
}
|
||||||
|
return strconv.ParseFloat(value, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
type timeParser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) {
|
||||||
|
result, err = time.Parse(time.RFC3339, value)
|
||||||
|
if err != nil {
|
||||||
|
result, err = time.Parse("2006-01-02", value)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type jsonParser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) {
|
||||||
|
pResult := reflect.New(toType)
|
||||||
|
v := pResult.Interface()
|
||||||
|
err := json.Unmarshal([]byte(value), v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return pResult.Elem().Interface(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sliceParser(elemParser paramParser) paramParser {
|
||||||
|
return parserFunc(func(value string, toType reflect.Type) (interface{}, error) {
|
||||||
|
values := strings.Split(value, ",")
|
||||||
|
result := reflect.MakeSlice(toType, 0, len(values))
|
||||||
|
elemType := toType.Elem()
|
||||||
|
for _, v := range values {
|
||||||
|
parsedValue, err := elemParser.parse(v, elemType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = reflect.Append(result, reflect.ValueOf(parsedValue))
|
||||||
|
}
|
||||||
|
return result.Interface(), nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ptrParser(elemParser paramParser) paramParser {
|
||||||
|
return parserFunc(func(value string, toType reflect.Type) (interface{}, error) {
|
||||||
|
parsedValue, err := elemParser.parse(value, toType.Elem())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newValPtr := reflect.New(toType.Elem())
|
||||||
|
newVal := reflect.Indirect(newValPtr)
|
||||||
|
convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newVal.Set(convertedVal)
|
||||||
|
return newValPtr.Interface(), nil
|
||||||
|
})
|
||||||
|
}
|
84
context/param/parsers_test.go
Normal file
84
context/param/parsers_test.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package param
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
import "reflect"
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type testDefinition struct {
|
||||||
|
strValue string
|
||||||
|
expectedValue interface{}
|
||||||
|
expectedParser paramParser
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Parsers(t *testing.T) {
|
||||||
|
|
||||||
|
//ints
|
||||||
|
checkParser(testDefinition{"1", 1, intParser{}}, t)
|
||||||
|
checkParser(testDefinition{"-1", int64(-1), intParser{}}, t)
|
||||||
|
checkParser(testDefinition{"1", uint64(1), intParser{}}, t)
|
||||||
|
|
||||||
|
//floats
|
||||||
|
checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t)
|
||||||
|
checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t)
|
||||||
|
|
||||||
|
//strings
|
||||||
|
checkParser(testDefinition{"AB", "AB", stringParser{}}, t)
|
||||||
|
checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t)
|
||||||
|
|
||||||
|
//bools
|
||||||
|
checkParser(testDefinition{"true", true, boolParser{}}, t)
|
||||||
|
checkParser(testDefinition{"0", false, boolParser{}}, t)
|
||||||
|
|
||||||
|
//timeParser
|
||||||
|
checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t)
|
||||||
|
checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t)
|
||||||
|
|
||||||
|
//json
|
||||||
|
checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct {
|
||||||
|
X int
|
||||||
|
Y string
|
||||||
|
}{5, "Z"}, jsonParser{}}, t)
|
||||||
|
|
||||||
|
//slice in query is parsed as comma delimited
|
||||||
|
checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t)
|
||||||
|
|
||||||
|
//slice in body is parsed as json
|
||||||
|
checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body})
|
||||||
|
|
||||||
|
//pointers
|
||||||
|
var someInt = 1
|
||||||
|
checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t)
|
||||||
|
|
||||||
|
var someStruct = struct{ X int }{5}
|
||||||
|
checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) {
|
||||||
|
toType := reflect.TypeOf(def.expectedValue)
|
||||||
|
var mp MethodParam
|
||||||
|
if len(methodParam) == 0 {
|
||||||
|
mp = MethodParam{}
|
||||||
|
} else {
|
||||||
|
mp = methodParam[0]
|
||||||
|
}
|
||||||
|
parser := getParser(&mp, toType)
|
||||||
|
|
||||||
|
if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) {
|
||||||
|
t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, err := parser.parse(def.strValue, toType)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
convResult, err := safeConvert(reflect.ValueOf(result), toType)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Conversion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) {
|
||||||
|
t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result)
|
||||||
|
}
|
||||||
|
}
|
12
context/renderer.go
Normal file
12
context/renderer.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package context
|
||||||
|
|
||||||
|
// Renderer defines an http response renderer
|
||||||
|
type Renderer interface {
|
||||||
|
Render(ctx *Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
type rendererFunc func(ctx *Context)
|
||||||
|
|
||||||
|
func (f rendererFunc) Render(ctx *Context) {
|
||||||
|
f(ctx)
|
||||||
|
}
|
27
context/response.go
Normal file
27
context/response.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package context
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
//BadRequest indicates http error 400
|
||||||
|
BadRequest StatusCode = http.StatusBadRequest
|
||||||
|
|
||||||
|
//NotFound indicates http error 404
|
||||||
|
NotFound StatusCode = http.StatusNotFound
|
||||||
|
)
|
||||||
|
|
||||||
|
// StatusCode sets the http response status code
|
||||||
|
type StatusCode int
|
||||||
|
|
||||||
|
func (s StatusCode) Error() string {
|
||||||
|
return strconv.Itoa(int(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render sets the http status code
|
||||||
|
func (s StatusCode) Render(ctx *Context) {
|
||||||
|
ctx.Output.SetStatus(int(s))
|
||||||
|
}
|
@ -28,6 +28,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/astaxie/beego/context"
|
"github.com/astaxie/beego/context"
|
||||||
|
"github.com/astaxie/beego/context/param"
|
||||||
"github.com/astaxie/beego/session"
|
"github.com/astaxie/beego/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -51,8 +52,16 @@ type ControllerComments struct {
|
|||||||
Router string
|
Router string
|
||||||
AllowHTTPMethods []string
|
AllowHTTPMethods []string
|
||||||
Params []map[string]string
|
Params []map[string]string
|
||||||
|
MethodParams []*param.MethodParam
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ControllerCommentsSlice implements the sort interface
|
||||||
|
type ControllerCommentsSlice []ControllerComments
|
||||||
|
|
||||||
|
func (p ControllerCommentsSlice) Len() int { return len(p) }
|
||||||
|
func (p ControllerCommentsSlice) Less(i, j int) bool { return p[i].Router < p[j].Router }
|
||||||
|
func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||||
|
|
||||||
// Controller defines some basic http request handler operations, such as
|
// Controller defines some basic http request handler operations, such as
|
||||||
// http context, template and view, session and xsrf.
|
// http context, template and view, session and xsrf.
|
||||||
type Controller struct {
|
type Controller struct {
|
||||||
@ -69,6 +78,7 @@ type Controller struct {
|
|||||||
|
|
||||||
// template data
|
// template data
|
||||||
TplName string
|
TplName string
|
||||||
|
ViewPath string
|
||||||
Layout string
|
Layout string
|
||||||
LayoutSections map[string]string // the key is the section name and the value is the template name
|
LayoutSections map[string]string // the key is the section name and the value is the template name
|
||||||
TplPrefix string
|
TplPrefix string
|
||||||
@ -185,7 +195,11 @@ func (c *Controller) Render() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
|
|
||||||
|
if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" {
|
||||||
|
c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
}
|
||||||
|
|
||||||
return c.Ctx.Output.Body(rb)
|
return c.Ctx.Output.Body(rb)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -209,7 +223,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
err = ExecuteTemplate(&buf, sectionTpl, c.Data)
|
err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -218,7 +232,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
ExecuteTemplate(&buf, c.Layout, c.Data)
|
ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data)
|
||||||
}
|
}
|
||||||
return buf.Bytes(), err
|
return buf.Bytes(), err
|
||||||
}
|
}
|
||||||
@ -244,9 +258,16 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...)
|
BuildTemplate(c.viewPath(), buildFiles...)
|
||||||
}
|
}
|
||||||
return buf, ExecuteTemplate(&buf, c.TplName, c.Data)
|
return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) viewPath() string {
|
||||||
|
if c.ViewPath == "" {
|
||||||
|
return BConfig.WebConfig.ViewsPath
|
||||||
|
}
|
||||||
|
return c.ViewPath
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect sends the redirection response to url with status code.
|
// Redirect sends the redirection response to url with status code.
|
||||||
@ -302,7 +323,7 @@ func (c *Controller) ServeJSON(encoding ...bool) {
|
|||||||
if BConfig.RunMode == PROD {
|
if BConfig.RunMode == PROD {
|
||||||
hasIndent = false
|
hasIndent = false
|
||||||
}
|
}
|
||||||
if len(encoding) > 0 && encoding[0] == true {
|
if len(encoding) > 0 && encoding[0] {
|
||||||
hasEncoding = true
|
hasEncoding = true
|
||||||
}
|
}
|
||||||
c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding)
|
c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding)
|
||||||
|
@ -20,6 +20,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/astaxie/beego/context"
|
"github.com/astaxie/beego/context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetInt(t *testing.T) {
|
func TestGetInt(t *testing.T) {
|
||||||
@ -121,3 +123,59 @@ func TestGetUint64(t *testing.T) {
|
|||||||
t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val)
|
t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdditionalViewPaths(t *testing.T) {
|
||||||
|
dir1 := "_beeTmp"
|
||||||
|
dir2 := "_beeTmp2"
|
||||||
|
defer os.RemoveAll(dir1)
|
||||||
|
defer os.RemoveAll(dir2)
|
||||||
|
|
||||||
|
dir1file := "file1.tpl"
|
||||||
|
dir2file := "file2.tpl"
|
||||||
|
|
||||||
|
genFile := func(dir string, name string, content string) {
|
||||||
|
os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
|
||||||
|
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else {
|
||||||
|
defer f.Close()
|
||||||
|
f.WriteString(content)
|
||||||
|
f.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
genFile(dir1, dir1file, `<div>{{.Content}}</div>`)
|
||||||
|
genFile(dir2, dir2file, `<html>{{.Content}}</html>`)
|
||||||
|
|
||||||
|
AddViewPath(dir1)
|
||||||
|
AddViewPath(dir2)
|
||||||
|
|
||||||
|
ctrl := Controller{
|
||||||
|
TplName: "file1.tpl",
|
||||||
|
ViewPath: dir1,
|
||||||
|
}
|
||||||
|
ctrl.Data = map[interface{}]interface{}{
|
||||||
|
"Content": "value2",
|
||||||
|
}
|
||||||
|
if result, err := ctrl.RenderString(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else {
|
||||||
|
if result != "<div>value2</div>" {
|
||||||
|
t.Fatalf("TestAdditionalViewPaths expect %s got %s", "<div>value2</div>", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func() {
|
||||||
|
ctrl.TplName = "file2.tpl"
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Fatal("TestAdditionalViewPaths expected error")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
ctrl.RenderString()
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctrl.TplName = "file2.tpl"
|
||||||
|
ctrl.ViewPath = dir2
|
||||||
|
ctrl.RenderString()
|
||||||
|
}
|
||||||
|
24
error.go
24
error.go
@ -252,6 +252,30 @@ func forbidden(rw http.ResponseWriter, r *http.Request) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// show 422 missing xsrf token
|
||||||
|
func missingxsrf(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
responseError(rw, r,
|
||||||
|
422,
|
||||||
|
"<br>The page you have requested is forbidden."+
|
||||||
|
"<br>Perhaps you are here because:"+
|
||||||
|
"<br><br><ul>"+
|
||||||
|
"<br>'_xsrf' argument missing from POST"+
|
||||||
|
"</ul>",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// show 417 invalid xsrf token
|
||||||
|
func invalidxsrf(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
responseError(rw, r,
|
||||||
|
417,
|
||||||
|
"<br>The page you have requested is forbidden."+
|
||||||
|
"<br>Perhaps you are here because:"+
|
||||||
|
"<br><br><ul>"+
|
||||||
|
"<br>expected XSRF not found"+
|
||||||
|
"</ul>",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// show 404 not found error.
|
// show 404 not found error.
|
||||||
func notFound(rw http.ResponseWriter, r *http.Request) {
|
func notFound(rw http.ResponseWriter, r *http.Request) {
|
||||||
responseError(rw, r,
|
responseError(rw, r,
|
||||||
|
@ -52,7 +52,7 @@ func TestErrorCode_01(t *testing.T) {
|
|||||||
if w.Code != code {
|
if w.Code != code {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
if !strings.Contains(string(w.Body.Bytes()), http.StatusText(code)) {
|
if !strings.Contains(w.Body.String(), http.StatusText(code)) {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -82,7 +82,7 @@ func TestErrorCode_03(t *testing.T) {
|
|||||||
if w.Code != 200 {
|
if w.Code != 200 {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
if string(w.Body.Bytes()) != parseCodeError {
|
if w.Body.String() != parseCodeError {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -48,7 +48,7 @@ func TestFlashHeader(t *testing.T) {
|
|||||||
// match for the expected header
|
// match for the expected header
|
||||||
res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
|
res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
|
||||||
// validate the assertion
|
// validate the assertion
|
||||||
if res != true {
|
if !res {
|
||||||
t.Errorf("TestFlashHeader() unable to validate flash message")
|
t.Errorf("TestFlashHeader() unable to validate flash message")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,14 +3,17 @@ package grace
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type graceConn struct {
|
type graceConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
server *Server
|
server *Server
|
||||||
|
m sync.Mutex
|
||||||
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c graceConn) Close() (err error) {
|
func (c *graceConn) Close() (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
switch x := r.(type) {
|
switch x := r.(type) {
|
||||||
@ -23,6 +26,14 @@ func (c graceConn) Close() (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
c.m.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.m.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
c.server.wg.Done()
|
c.server.wg.Done()
|
||||||
|
c.closed = true
|
||||||
|
c.m.Unlock()
|
||||||
return c.Conn.Close()
|
return c.Conn.Close()
|
||||||
}
|
}
|
||||||
|
@ -85,23 +85,31 @@ var (
|
|||||||
|
|
||||||
isChild bool
|
isChild bool
|
||||||
socketOrder string
|
socketOrder string
|
||||||
once sync.Once
|
|
||||||
|
hookableSignals []os.Signal
|
||||||
)
|
)
|
||||||
|
|
||||||
func onceInit() {
|
func init() {
|
||||||
regLock = &sync.Mutex{}
|
|
||||||
flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
|
flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
|
||||||
flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
|
flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
|
||||||
|
|
||||||
|
regLock = &sync.Mutex{}
|
||||||
runningServers = make(map[string]*Server)
|
runningServers = make(map[string]*Server)
|
||||||
runningServersOrder = []string{}
|
runningServersOrder = []string{}
|
||||||
socketPtrOffsetMap = make(map[string]uint)
|
socketPtrOffsetMap = make(map[string]uint)
|
||||||
|
|
||||||
|
hookableSignals = []os.Signal{
|
||||||
|
syscall.SIGHUP,
|
||||||
|
syscall.SIGINT,
|
||||||
|
syscall.SIGTERM,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer returns a new graceServer.
|
// NewServer returns a new graceServer.
|
||||||
func NewServer(addr string, handler http.Handler) (srv *Server) {
|
func NewServer(addr string, handler http.Handler) (srv *Server) {
|
||||||
once.Do(onceInit)
|
|
||||||
regLock.Lock()
|
regLock.Lock()
|
||||||
defer regLock.Unlock()
|
defer regLock.Unlock()
|
||||||
|
|
||||||
if !flag.Parsed() {
|
if !flag.Parsed() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,7 @@ func newGraceListener(l net.Listener, srv *Server) (el *graceListener) {
|
|||||||
server: srv,
|
server: srv,
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
_ = <-el.stop
|
<-el.stop
|
||||||
el.stopped = true
|
el.stopped = true
|
||||||
el.stop <- el.Listener.Close()
|
el.stop <- el.Listener.Close()
|
||||||
}()
|
}()
|
||||||
@ -37,7 +37,7 @@ func (gl *graceListener) Accept() (c net.Conn, err error) {
|
|||||||
tc.SetKeepAlive(true)
|
tc.SetKeepAlive(true)
|
||||||
tc.SetKeepAlivePeriod(3 * time.Minute)
|
tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||||
|
|
||||||
c = graceConn{
|
c = &graceConn{
|
||||||
Conn: tc,
|
Conn: tc,
|
||||||
server: gl.server,
|
server: gl.server,
|
||||||
}
|
}
|
||||||
|
@ -162,9 +162,7 @@ func (srv *Server) handleSignals() {
|
|||||||
|
|
||||||
signal.Notify(
|
signal.Notify(
|
||||||
srv.sigChan,
|
srv.sigChan,
|
||||||
syscall.SIGHUP,
|
hookableSignals...,
|
||||||
syscall.SIGINT,
|
|
||||||
syscall.SIGTERM,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
pid := syscall.Getpid()
|
pid := syscall.Getpid()
|
||||||
@ -198,7 +196,6 @@ func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
|
|||||||
for _, f := range srv.SignalHooks[ppFlag][sig] {
|
for _, f := range srv.SignalHooks[ppFlag][sig] {
|
||||||
f()
|
f()
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// shutdown closes the listener so that no new connections are accepted. it also
|
// shutdown closes the listener so that no new connections are accepted. it also
|
||||||
@ -290,3 +287,19 @@ func (srv *Server) fork() (err error) {
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal.
|
||||||
|
func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
|
||||||
|
if ppFlag != PreSignal && ppFlag != PostSignal {
|
||||||
|
err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, s := range hookableSignals {
|
||||||
|
if s == sig {
|
||||||
|
srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("Signal '%v' is not supported", sig)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
11
hooks.go
11
hooks.go
@ -32,6 +32,8 @@ func registerDefaultErrorHandler() error {
|
|||||||
"502": badGateway,
|
"502": badGateway,
|
||||||
"503": serviceUnavailable,
|
"503": serviceUnavailable,
|
||||||
"504": gatewayTimeout,
|
"504": gatewayTimeout,
|
||||||
|
"417": invalidxsrf,
|
||||||
|
"422": missingxsrf,
|
||||||
}
|
}
|
||||||
for e, h := range m {
|
for e, h := range m {
|
||||||
if _, ok := ErrorMaps[e]; !ok {
|
if _, ok := ErrorMaps[e]; !ok {
|
||||||
@ -55,9 +57,9 @@ func registerSession() error {
|
|||||||
conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig)
|
conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig)
|
||||||
conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly
|
conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly
|
||||||
conf.Domain = BConfig.WebConfig.Session.SessionDomain
|
conf.Domain = BConfig.WebConfig.Session.SessionDomain
|
||||||
conf.EnableSidInHttpHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader
|
conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader
|
||||||
conf.SessionNameInHttpHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader
|
conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader
|
||||||
conf.EnableSidInUrlQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery
|
conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery
|
||||||
} else {
|
} else {
|
||||||
if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil {
|
if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -72,7 +74,8 @@ func registerSession() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func registerTemplate() error {
|
func registerTemplate() error {
|
||||||
if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil {
|
defer lockViewPaths()
|
||||||
|
if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil {
|
||||||
if BConfig.RunMode == DEV {
|
if BConfig.RunMode == DEV {
|
||||||
logs.Warn(err)
|
logs.Warn(err)
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,7 @@ The default timeout is `60` seconds, function prototype:
|
|||||||
|
|
||||||
SetTimeout(connectTimeout, readWriteTimeout time.Duration)
|
SetTimeout(connectTimeout, readWriteTimeout time.Duration)
|
||||||
|
|
||||||
Exmaple:
|
Example:
|
||||||
|
|
||||||
// GET
|
// GET
|
||||||
httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)
|
httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)
|
||||||
|
@ -140,6 +140,7 @@ type BeegoHTTPSettings struct {
|
|||||||
EnableCookie bool
|
EnableCookie bool
|
||||||
Gzip bool
|
Gzip bool
|
||||||
DumpBody bool
|
DumpBody bool
|
||||||
|
Retries int // if set to -1 means will retry forever
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request.
|
// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request.
|
||||||
@ -189,6 +190,15 @@ func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Retries sets Retries times.
|
||||||
|
// default is 0 means no retried.
|
||||||
|
// -1 means retried forever.
|
||||||
|
// others means retried times.
|
||||||
|
func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest {
|
||||||
|
b.setting.Retries = times
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// DumpBody setting whether need to Dump the Body.
|
// DumpBody setting whether need to Dump the Body.
|
||||||
func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
|
func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
|
||||||
b.setting.DumpBody = isdump
|
b.setting.DumpBody = isdump
|
||||||
@ -325,7 +335,7 @@ func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error)
|
|||||||
func (b *BeegoHTTPRequest) buildURL(paramBody string) {
|
func (b *BeegoHTTPRequest) buildURL(paramBody string) {
|
||||||
// build GET url with query string
|
// build GET url with query string
|
||||||
if b.req.Method == "GET" && len(paramBody) > 0 {
|
if b.req.Method == "GET" && len(paramBody) > 0 {
|
||||||
if strings.Index(b.url, "?") != -1 {
|
if strings.Contains(b.url, "?") {
|
||||||
b.url += "&" + paramBody
|
b.url += "&" + paramBody
|
||||||
} else {
|
} else {
|
||||||
b.url = b.url + "?" + paramBody
|
b.url = b.url + "?" + paramBody
|
||||||
@ -334,7 +344,7 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// build POST/PUT/PATCH url and body
|
// build POST/PUT/PATCH url and body
|
||||||
if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH") && b.req.Body == nil {
|
if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil {
|
||||||
// with files
|
// with files
|
||||||
if len(b.files) > 0 {
|
if len(b.files) > 0 {
|
||||||
pr, pw := io.Pipe()
|
pr, pw := io.Pipe()
|
||||||
@ -390,7 +400,7 @@ func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DoRequest will do the client.Do
|
// DoRequest will do the client.Do
|
||||||
func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
|
func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) {
|
||||||
var paramBody string
|
var paramBody string
|
||||||
if len(b.params) > 0 {
|
if len(b.params) > 0 {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
@ -467,7 +477,16 @@ func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
|
|||||||
}
|
}
|
||||||
b.dump = dump
|
b.dump = dump
|
||||||
}
|
}
|
||||||
return client.Do(b.req)
|
// retries default value is 0, it will run once.
|
||||||
|
// retries equal to -1, it will run forever until success
|
||||||
|
// retries is setted, it will retries fixed times.
|
||||||
|
for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
|
||||||
|
resp, err = client.Do(b.req)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns the body string in response.
|
// String returns the body string in response.
|
||||||
@ -501,9 +520,9 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
b.body, err = ioutil.ReadAll(reader)
|
b.body, err = ioutil.ReadAll(reader)
|
||||||
} else {
|
return b.body, err
|
||||||
b.body, err = ioutil.ReadAll(resp.Body)
|
|
||||||
}
|
}
|
||||||
|
b.body, err = ioutil.ReadAll(resp.Body)
|
||||||
return b.body, err
|
return b.body, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,6 +102,14 @@ func TestSimpleDelete(t *testing.T) {
|
|||||||
t.Log(str)
|
t.Log(str)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSimpleDeleteParam(t *testing.T) {
|
||||||
|
str, err := Delete("http://httpbin.org/delete").Param("key", "val").String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(str)
|
||||||
|
}
|
||||||
|
|
||||||
func TestWithCookie(t *testing.T) {
|
func TestWithCookie(t *testing.T) {
|
||||||
v := "smallfish"
|
v := "smallfish"
|
||||||
str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String()
|
str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String()
|
||||||
|
186
logs/alils/alils.go
Normal file
186
logs/alils/alils.go
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
package alils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego/logs"
|
||||||
|
"github.com/gogo/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// CacheSize set the flush size
|
||||||
|
CacheSize int = 64
|
||||||
|
// Delimiter define the topic delimiter
|
||||||
|
Delimiter string = "##"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config is the Config for Ali Log
|
||||||
|
type Config struct {
|
||||||
|
Project string `json:"project"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
KeyID string `json:"key_id"`
|
||||||
|
KeySecret string `json:"key_secret"`
|
||||||
|
LogStore string `json:"log_store"`
|
||||||
|
Topics []string `json:"topics"`
|
||||||
|
Source string `json:"source"`
|
||||||
|
Level int `json:"level"`
|
||||||
|
FlushWhen int `json:"flush_when"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// aliLSWriter implements LoggerInterface.
|
||||||
|
// it writes messages in keep-live tcp connection.
|
||||||
|
type aliLSWriter struct {
|
||||||
|
store *LogStore
|
||||||
|
group []*LogGroup
|
||||||
|
withMap bool
|
||||||
|
groupMap map[string]*LogGroup
|
||||||
|
lock *sync.Mutex
|
||||||
|
Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAliLS create a new Logger
|
||||||
|
func NewAliLS() logs.Logger {
|
||||||
|
alils := new(aliLSWriter)
|
||||||
|
alils.Level = logs.LevelTrace
|
||||||
|
return alils
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init parse config and init struct
|
||||||
|
func (c *aliLSWriter) Init(jsonConfig string) (err error) {
|
||||||
|
|
||||||
|
json.Unmarshal([]byte(jsonConfig), c)
|
||||||
|
|
||||||
|
if c.FlushWhen > CacheSize {
|
||||||
|
c.FlushWhen = CacheSize
|
||||||
|
}
|
||||||
|
|
||||||
|
prj := &LogProject{
|
||||||
|
Name: c.Project,
|
||||||
|
Endpoint: c.Endpoint,
|
||||||
|
AccessKeyID: c.KeyID,
|
||||||
|
AccessKeySecret: c.KeySecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.store, err = prj.GetLogStore(c.LogStore)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create default Log Group
|
||||||
|
c.group = append(c.group, &LogGroup{
|
||||||
|
Topic: proto.String(""),
|
||||||
|
Source: proto.String(c.Source),
|
||||||
|
Logs: make([]*Log, 0, c.FlushWhen),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create other Log Group
|
||||||
|
c.groupMap = make(map[string]*LogGroup)
|
||||||
|
for _, topic := range c.Topics {
|
||||||
|
|
||||||
|
lg := &LogGroup{
|
||||||
|
Topic: proto.String(topic),
|
||||||
|
Source: proto.String(c.Source),
|
||||||
|
Logs: make([]*Log, 0, c.FlushWhen),
|
||||||
|
}
|
||||||
|
|
||||||
|
c.group = append(c.group, lg)
|
||||||
|
c.groupMap[topic] = lg
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.group) == 1 {
|
||||||
|
c.withMap = false
|
||||||
|
} else {
|
||||||
|
c.withMap = true
|
||||||
|
}
|
||||||
|
|
||||||
|
c.lock = &sync.Mutex{}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteMsg write message in connection.
|
||||||
|
// if connection is down, try to re-connect.
|
||||||
|
func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) {
|
||||||
|
|
||||||
|
if level > c.Level {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var topic string
|
||||||
|
var content string
|
||||||
|
var lg *LogGroup
|
||||||
|
if c.withMap {
|
||||||
|
|
||||||
|
// Topic,LogGroup
|
||||||
|
strs := strings.SplitN(msg, Delimiter, 2)
|
||||||
|
if len(strs) == 2 {
|
||||||
|
pos := strings.LastIndex(strs[0], " ")
|
||||||
|
topic = strs[0][pos+1 : len(strs[0])]
|
||||||
|
content = strs[0][0:pos] + strs[1]
|
||||||
|
lg = c.groupMap[topic]
|
||||||
|
}
|
||||||
|
|
||||||
|
// send to empty Topic
|
||||||
|
if lg == nil {
|
||||||
|
content = msg
|
||||||
|
lg = c.group[0]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content = msg
|
||||||
|
lg = c.group[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
c1 := &LogContent{
|
||||||
|
Key: proto.String("msg"),
|
||||||
|
Value: proto.String(content),
|
||||||
|
}
|
||||||
|
|
||||||
|
l := &Log{
|
||||||
|
Time: proto.Uint32(uint32(when.Unix())),
|
||||||
|
Contents: []*LogContent{
|
||||||
|
c1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c.lock.Lock()
|
||||||
|
lg.Logs = append(lg.Logs, l)
|
||||||
|
c.lock.Unlock()
|
||||||
|
|
||||||
|
if len(lg.Logs) >= c.FlushWhen {
|
||||||
|
c.flush(lg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush implementing method. empty.
|
||||||
|
func (c *aliLSWriter) Flush() {
|
||||||
|
|
||||||
|
// flush all group
|
||||||
|
for _, lg := range c.group {
|
||||||
|
c.flush(lg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy destroy connection writer and close tcp listener.
|
||||||
|
func (c *aliLSWriter) Destroy() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *aliLSWriter) flush(lg *LogGroup) {
|
||||||
|
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
err := c.store.PutLogs(lg)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lg.Logs = make([]*Log, 0, c.FlushWhen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
logs.Register(logs.AdapterAliLS, NewAliLS)
|
||||||
|
}
|
13
logs/alils/config.go
Executable file
13
logs/alils/config.go
Executable file
@ -0,0 +1,13 @@
|
|||||||
|
package alils
|
||||||
|
|
||||||
|
const (
|
||||||
|
version = "0.5.0" // SDK version
|
||||||
|
signatureMethod = "hmac-sha1" // Signature method
|
||||||
|
|
||||||
|
// OffsetNewest stands for the log head offset, i.e. the offset that will be
|
||||||
|
// assigned to the next message that will be produced to the shard.
|
||||||
|
OffsetNewest = "end"
|
||||||
|
// OffsetOldest stands for the oldest offset available on the logstore for a
|
||||||
|
// shard.
|
||||||
|
OffsetOldest = "begin"
|
||||||
|
)
|
1038
logs/alils/log.pb.go
Executable file
1038
logs/alils/log.pb.go
Executable file
File diff suppressed because it is too large
Load Diff
42
logs/alils/log_config.go
Executable file
42
logs/alils/log_config.go
Executable file
@ -0,0 +1,42 @@
|
|||||||
|
package alils
|
||||||
|
|
||||||
|
// InputDetail define log detail
|
||||||
|
type InputDetail struct {
|
||||||
|
LogType string `json:"logType"`
|
||||||
|
LogPath string `json:"logPath"`
|
||||||
|
FilePattern string `json:"filePattern"`
|
||||||
|
LocalStorage bool `json:"localStorage"`
|
||||||
|
TimeFormat string `json:"timeFormat"`
|
||||||
|
LogBeginRegex string `json:"logBeginRegex"`
|
||||||
|
Regex string `json:"regex"`
|
||||||
|
Keys []string `json:"key"`
|
||||||
|
FilterKeys []string `json:"filterKey"`
|
||||||
|
FilterRegex []string `json:"filterRegex"`
|
||||||
|
TopicFormat string `json:"topicFormat"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OutputDetail define the output detail
|
||||||
|
type OutputDetail struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
LogStoreName string `json:"logstoreName"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogConfig define Log Config
|
||||||
|
type LogConfig struct {
|
||||||
|
Name string `json:"configName"`
|
||||||
|
InputType string `json:"inputType"`
|
||||||
|
InputDetail InputDetail `json:"inputDetail"`
|
||||||
|
OutputType string `json:"outputType"`
|
||||||
|
OutputDetail OutputDetail `json:"outputDetail"`
|
||||||
|
|
||||||
|
CreateTime uint32
|
||||||
|
LastModifyTime uint32
|
||||||
|
|
||||||
|
project *LogProject
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAppliedMachineGroup returns applied machine group of this config.
|
||||||
|
func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) {
|
||||||
|
groupNames, err = c.project.GetAppliedMachineGroups(c.Name)
|
||||||
|
return
|
||||||
|
}
|
819
logs/alils/log_project.go
Executable file
819
logs/alils/log_project.go
Executable file
@ -0,0 +1,819 @@
|
|||||||
|
/*
|
||||||
|
Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS).
|
||||||
|
|
||||||
|
For more description about SLS, please read this article:
|
||||||
|
http://gitlab.alibaba-inc.com/sls/doc.
|
||||||
|
*/
|
||||||
|
package alils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error message in SLS HTTP response.
|
||||||
|
type errorMessage struct {
|
||||||
|
Code string `json:"errorCode"`
|
||||||
|
Message string `json:"errorMessage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogProject Define the Ali Project detail
|
||||||
|
type LogProject struct {
|
||||||
|
Name string // Project name
|
||||||
|
Endpoint string // IP or hostname of SLS endpoint
|
||||||
|
AccessKeyID string
|
||||||
|
AccessKeySecret string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLogProject creates a new SLS project.
|
||||||
|
func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) {
|
||||||
|
p = &LogProject{
|
||||||
|
Name: name,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
AccessKeyID: AccessKeyID,
|
||||||
|
AccessKeySecret: accessKeySecret,
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListLogStore returns all logstore names of project p.
|
||||||
|
func (p *LogProject) ListLogStore() (storeNames []string, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/logstores")
|
||||||
|
r, err := request(p, "GET", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to list logstore")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type Body struct {
|
||||||
|
Count int
|
||||||
|
LogStores []string
|
||||||
|
}
|
||||||
|
body := &Body{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(buf, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
storeNames = body.LogStores
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLogStore returns logstore according by logstore name.
|
||||||
|
func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "GET", "/logstores/"+name, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to get logstore")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s = &LogStore{}
|
||||||
|
err = json.Unmarshal(buf, s)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.project = p
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateLogStore creates a new logstore in SLS,
|
||||||
|
// where name is logstore name,
|
||||||
|
// and ttl is time-to-live(in day) of logs,
|
||||||
|
// and shardCnt is the number of shards.
|
||||||
|
func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) {
|
||||||
|
|
||||||
|
type Body struct {
|
||||||
|
Name string `json:"logstoreName"`
|
||||||
|
TTL int `json:"ttl"`
|
||||||
|
ShardCount int `json:"shardCount"`
|
||||||
|
}
|
||||||
|
|
||||||
|
store := &Body{
|
||||||
|
Name: name,
|
||||||
|
TTL: ttl,
|
||||||
|
ShardCount: shardCnt,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(store)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept-Encoding": "deflate", // TODO: support lz4
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "POST", "/logstores", h, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err = ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(body, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to create logstore")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteLogStore deletes a logstore according by logstore name.
|
||||||
|
func (p *LogProject) DeleteLogStore(name string) (err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "DELETE", "/logstores/"+name, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(body, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to delete logstore")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLogStore updates a logstore according by logstore name,
|
||||||
|
// obviously we can't modify the logstore name itself.
|
||||||
|
func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) {
|
||||||
|
|
||||||
|
type Body struct {
|
||||||
|
Name string `json:"logstoreName"`
|
||||||
|
TTL int `json:"ttl"`
|
||||||
|
ShardCount int `json:"shardCount"`
|
||||||
|
}
|
||||||
|
|
||||||
|
store := &Body{
|
||||||
|
Name: name,
|
||||||
|
TTL: ttl,
|
||||||
|
ShardCount: shardCnt,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(store)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept-Encoding": "deflate", // TODO: support lz4
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "PUT", "/logstores", h, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err = ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(body, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to update logstore")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListMachineGroup returns machine group name list and the total number of machine groups.
|
||||||
|
// The offset starts from 0 and the size is the max number of machine groups could be returned.
|
||||||
|
func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
if size <= 0 {
|
||||||
|
size = 500
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size)
|
||||||
|
r, err := request(p, "GET", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to list machine group")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type Body struct {
|
||||||
|
MachineGroups []string
|
||||||
|
Count int
|
||||||
|
Total int
|
||||||
|
}
|
||||||
|
body := &Body{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(buf, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m = body.MachineGroups
|
||||||
|
total = body.Total
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMachineGroup retruns machine group according by machine group name.
|
||||||
|
func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "GET", "/machinegroups/"+name, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to get machine group:%v", name)
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m = &MachineGroup{}
|
||||||
|
err = json.Unmarshal(buf, m)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.project = p
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMachineGroup creates a new machine group in SLS.
|
||||||
|
func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) {
|
||||||
|
|
||||||
|
body, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept-Encoding": "deflate", // TODO: support lz4
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "POST", "/machinegroups", h, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err = ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(body, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to create machine group")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMachineGroup updates a machine group.
|
||||||
|
func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) {
|
||||||
|
|
||||||
|
body, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept-Encoding": "deflate", // TODO: support lz4
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err = ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(body, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to update machine group")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMachineGroup deletes machine group according machine group name.
|
||||||
|
func (p *LogProject) DeleteMachineGroup(name string) (err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(body, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to delete machine group")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListConfig returns config names list and the total number of configs.
|
||||||
|
// The offset starts from 0 and the size is the max number of configs could be returned.
|
||||||
|
func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
if size <= 0 {
|
||||||
|
size = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size)
|
||||||
|
r, err := request(p, "GET", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to delete machine group")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type Body struct {
|
||||||
|
Total int
|
||||||
|
Configs []string
|
||||||
|
}
|
||||||
|
body := &Body{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(buf, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfgNames = body.Configs
|
||||||
|
total = body.Total
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig returns config according by config name.
|
||||||
|
func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "GET", "/configs/"+name, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to delete config")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c = &LogConfig{}
|
||||||
|
err = json.Unmarshal(buf, c)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.project = p
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateConfig updates a config.
|
||||||
|
func (p *LogProject) UpdateConfig(c *LogConfig) (err error) {
|
||||||
|
|
||||||
|
body, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept-Encoding": "deflate", // TODO: support lz4
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "PUT", "/configs/"+c.Name, h, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err = ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(body, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to update config")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateConfig creates a new config in SLS.
|
||||||
|
func (p *LogProject) CreateConfig(c *LogConfig) (err error) {
|
||||||
|
|
||||||
|
body, err := json.Marshal(c)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept-Encoding": "deflate", // TODO: support lz4
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "POST", "/configs", h, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err = ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(body, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to update config")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteConfig deletes a config according by config name.
|
||||||
|
func (p *LogProject) DeleteConfig(name string) (err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := request(p, "DELETE", "/configs/"+name, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(body, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to delete config")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAppliedMachineGroups returns applied machine group names list according config name.
|
||||||
|
func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/configs/%v/machinegroups", confName)
|
||||||
|
r, err := request(p, "GET", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to get applied machine groups")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type Body struct {
|
||||||
|
Count int
|
||||||
|
Machinegroups []string
|
||||||
|
}
|
||||||
|
|
||||||
|
body := &Body{}
|
||||||
|
err = json.Unmarshal(buf, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
groupNames = body.Machinegroups
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAppliedConfigs returns applied config names list according machine group name groupName.
|
||||||
|
func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/machinegroups/%v/configs", groupName)
|
||||||
|
r, err := request(p, "GET", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to applied configs")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type Cfg struct {
|
||||||
|
Count int `json:"count"`
|
||||||
|
Configs []string `json:"configs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
body := &Cfg{}
|
||||||
|
err = json.Unmarshal(buf, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
confNames = body.Configs
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyConfigToMachineGroup applies config to machine group.
|
||||||
|
func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName)
|
||||||
|
r, err := request(p, "PUT", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to apply config to machine group")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveConfigFromMachineGroup removes config from machine group.
|
||||||
|
func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName)
|
||||||
|
r, err := request(p, "DELETE", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to remove config from machine group")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Printf("%s\n", dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
271
logs/alils/log_store.go
Executable file
271
logs/alils/log_store.go
Executable file
@ -0,0 +1,271 @@
|
|||||||
|
package alils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
lz4 "github.com/cloudflare/golz4"
|
||||||
|
"github.com/gogo/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogStore Store the logs
|
||||||
|
type LogStore struct {
|
||||||
|
Name string `json:"logstoreName"`
|
||||||
|
TTL int
|
||||||
|
ShardCount int
|
||||||
|
|
||||||
|
CreateTime uint32
|
||||||
|
LastModifyTime uint32
|
||||||
|
|
||||||
|
project *LogProject
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shard define the Log Shard
|
||||||
|
type Shard struct {
|
||||||
|
ShardID int `json:"shardID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListShards returns shard id list of this logstore.
|
||||||
|
func (s *LogStore) ListShards() (shardIDs []int, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/logstores/%v/shards", s.Name)
|
||||||
|
r, err := request(s.project, "GET", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to list logstore")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Println(dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var shards []*Shard
|
||||||
|
err = json.Unmarshal(buf, &shards)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range shards {
|
||||||
|
shardIDs = append(shardIDs, v.ShardID)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutLogs put logs into logstore.
|
||||||
|
// The callers should transform user logs into LogGroup.
|
||||||
|
func (s *LogStore) PutLogs(lg *LogGroup) (err error) {
|
||||||
|
body, err := proto.Marshal(lg)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compresse body with lz4
|
||||||
|
out := make([]byte, lz4.CompressBound(body))
|
||||||
|
n, err := lz4.Compress(body, out)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-compresstype": "lz4",
|
||||||
|
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
|
||||||
|
"Content-Type": "application/x-protobuf",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/logstores/%v", s.Name)
|
||||||
|
r, err := request(s.project, "POST", uri, h, out[:n])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to put logs")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Println(dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCursor gets log cursor of one shard specified by shardID.
|
||||||
|
// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end".
|
||||||
|
// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore
|
||||||
|
func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v",
|
||||||
|
s.Name, shardID, from)
|
||||||
|
|
||||||
|
r, err := request(s.project, "GET", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to get cursor")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Println(dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type Body struct {
|
||||||
|
Cursor string
|
||||||
|
}
|
||||||
|
body := &Body{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(buf, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cursor = body.Cursor
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLogsBytes gets logs binary data from shard specified by shardID according cursor.
|
||||||
|
// The logGroupMaxCount is the max number of logGroup could be returned.
|
||||||
|
// The nextCursor is the next curosr can be used to read logs at next time.
|
||||||
|
func (s *LogStore) GetLogsBytes(shardID int, cursor string,
|
||||||
|
logGroupMaxCount int) (out []byte, nextCursor string, err error) {
|
||||||
|
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
"Accept": "application/x-protobuf",
|
||||||
|
"Accept-Encoding": "lz4",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v",
|
||||||
|
s.Name, shardID, cursor, logGroupMaxCount)
|
||||||
|
|
||||||
|
r, err := request(s.project, "GET", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to get cursor")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Println(dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok := r.Header["X-Sls-Compresstype"]
|
||||||
|
if !ok || len(v) == 0 {
|
||||||
|
err = fmt.Errorf("can't find 'x-sls-compresstype' header")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v[0] != "lz4" {
|
||||||
|
err = fmt.Errorf("unexpected compress type:%v", v[0])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok = r.Header["X-Sls-Cursor"]
|
||||||
|
if !ok || len(v) == 0 {
|
||||||
|
err = fmt.Errorf("can't find 'x-sls-cursor' header")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nextCursor = v[0]
|
||||||
|
|
||||||
|
v, ok = r.Header["X-Sls-Bodyrawsize"]
|
||||||
|
if !ok || len(v) == 0 {
|
||||||
|
err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bodyRawSize, err := strconv.Atoi(v[0])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
out = make([]byte, bodyRawSize)
|
||||||
|
err = lz4.Uncompress(buf, out)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API
|
||||||
|
func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) {
|
||||||
|
|
||||||
|
gl = &LogGroupList{}
|
||||||
|
err = proto.Unmarshal(data, gl)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLogs gets logs from shard specified by shardID according cursor.
|
||||||
|
// The logGroupMaxCount is the max number of logGroup could be returned.
|
||||||
|
// The nextCursor is the next curosr can be used to read logs at next time.
|
||||||
|
func (s *LogStore) GetLogs(shardID int, cursor string,
|
||||||
|
logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) {
|
||||||
|
|
||||||
|
out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
gl, err = LogsBytesDecode(out)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
91
logs/alils/machine_group.go
Executable file
91
logs/alils/machine_group.go
Executable file
@ -0,0 +1,91 @@
|
|||||||
|
package alils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MachineGroupAttribute define the Attribute
|
||||||
|
type MachineGroupAttribute struct {
|
||||||
|
ExternalName string `json:"externalName"`
|
||||||
|
TopicName string `json:"groupTopic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MachineGroup define the machine Group
|
||||||
|
type MachineGroup struct {
|
||||||
|
Name string `json:"groupName"`
|
||||||
|
Type string `json:"groupType"`
|
||||||
|
MachineIDType string `json:"machineIdentifyType"`
|
||||||
|
MachineIDList []string `json:"machineList"`
|
||||||
|
|
||||||
|
Attribute MachineGroupAttribute `json:"groupAttribute"`
|
||||||
|
|
||||||
|
CreateTime uint32
|
||||||
|
LastModifyTime uint32
|
||||||
|
|
||||||
|
project *LogProject
|
||||||
|
}
|
||||||
|
|
||||||
|
// Machine define the Machine
|
||||||
|
type Machine struct {
|
||||||
|
IP string
|
||||||
|
UniqueID string `json:"machine-uniqueid"`
|
||||||
|
UserdefinedID string `json:"userdefined-id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MachineList define the Machine List
|
||||||
|
type MachineList struct {
|
||||||
|
Total int
|
||||||
|
Machines []*Machine
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListMachines returns machine list of this machine group.
|
||||||
|
func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) {
|
||||||
|
h := map[string]string{
|
||||||
|
"x-sls-bodyrawsize": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name)
|
||||||
|
r, err := request(m.project, "GET", uri, h, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.StatusCode != http.StatusOK {
|
||||||
|
errMsg := &errorMessage{}
|
||||||
|
err = json.Unmarshal(buf, errMsg)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to remove config from machine group")
|
||||||
|
dump, _ := httputil.DumpResponse(r, true)
|
||||||
|
fmt.Println(dump)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body := &MachineList{}
|
||||||
|
err = json.Unmarshal(buf, body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ms = body.Machines
|
||||||
|
total = body.Total
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAppliedConfigs returns applied configs of this machine group.
|
||||||
|
func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) {
|
||||||
|
confNames, err = m.project.GetAppliedConfigs(m.Name)
|
||||||
|
return
|
||||||
|
}
|
62
logs/alils/request.go
Executable file
62
logs/alils/request.go
Executable file
@ -0,0 +1,62 @@
|
|||||||
|
package alils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/md5"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// request sends a request to SLS.
|
||||||
|
func request(project *LogProject, method, uri string, headers map[string]string,
|
||||||
|
body []byte) (resp *http.Response, err error) {
|
||||||
|
|
||||||
|
// The caller should provide 'x-sls-bodyrawsize' header
|
||||||
|
if _, ok := headers["x-sls-bodyrawsize"]; !ok {
|
||||||
|
err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// SLS public request headers
|
||||||
|
headers["Host"] = project.Name + "." + project.Endpoint
|
||||||
|
headers["Date"] = nowRFC1123()
|
||||||
|
headers["x-sls-apiversion"] = version
|
||||||
|
headers["x-sls-signaturemethod"] = signatureMethod
|
||||||
|
if body != nil {
|
||||||
|
bodyMD5 := fmt.Sprintf("%X", md5.Sum(body))
|
||||||
|
headers["Content-MD5"] = bodyMD5
|
||||||
|
|
||||||
|
if _, ok := headers["Content-Type"]; !ok {
|
||||||
|
err = fmt.Errorf("Can't find 'Content-Type' header")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calc Authorization
|
||||||
|
// Authorization = "SLS <AccessKeyID>:<Signature>"
|
||||||
|
digest, err := signature(project, method, uri, headers)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest)
|
||||||
|
headers["Authorization"] = auth
|
||||||
|
|
||||||
|
// Initialize http request
|
||||||
|
reader := bytes.NewReader(body)
|
||||||
|
urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri)
|
||||||
|
req, err := http.NewRequest(method, urlStr, reader)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for k, v := range headers {
|
||||||
|
req.Header.Add(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get ready to do request
|
||||||
|
resp, err = http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
111
logs/alils/signature.go
Executable file
111
logs/alils/signature.go
Executable file
@ -0,0 +1,111 @@
|
|||||||
|
package alils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GMT location
|
||||||
|
var gmtLoc = time.FixedZone("GMT", 0)
|
||||||
|
|
||||||
|
// NowRFC1123 returns now time in RFC1123 format with GMT timezone,
|
||||||
|
// eg. "Mon, 02 Jan 2006 15:04:05 GMT".
|
||||||
|
func nowRFC1123() string {
|
||||||
|
return time.Now().In(gmtLoc).Format(time.RFC1123)
|
||||||
|
}
|
||||||
|
|
||||||
|
// signature calculates a request's signature digest.
|
||||||
|
func signature(project *LogProject, method, uri string,
|
||||||
|
headers map[string]string) (digest string, err error) {
|
||||||
|
var contentMD5, contentType, date, canoHeaders, canoResource string
|
||||||
|
var slsHeaderKeys sort.StringSlice
|
||||||
|
|
||||||
|
// SignString = VERB + "\n"
|
||||||
|
// + CONTENT-MD5 + "\n"
|
||||||
|
// + CONTENT-TYPE + "\n"
|
||||||
|
// + DATE + "\n"
|
||||||
|
// + CanonicalizedSLSHeaders + "\n"
|
||||||
|
// + CanonicalizedResource
|
||||||
|
|
||||||
|
if val, ok := headers["Content-MD5"]; ok {
|
||||||
|
contentMD5 = val
|
||||||
|
}
|
||||||
|
|
||||||
|
if val, ok := headers["Content-Type"]; ok {
|
||||||
|
contentType = val
|
||||||
|
}
|
||||||
|
|
||||||
|
date, ok := headers["Date"]
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf("Can't find 'Date' header")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calc CanonicalizedSLSHeaders
|
||||||
|
slsHeaders := make(map[string]string, len(headers))
|
||||||
|
for k, v := range headers {
|
||||||
|
l := strings.TrimSpace(strings.ToLower(k))
|
||||||
|
if strings.HasPrefix(l, "x-sls-") {
|
||||||
|
slsHeaders[l] = strings.TrimSpace(v)
|
||||||
|
slsHeaderKeys = append(slsHeaderKeys, l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(slsHeaderKeys)
|
||||||
|
for i, k := range slsHeaderKeys {
|
||||||
|
canoHeaders += k + ":" + slsHeaders[k]
|
||||||
|
if i+1 < len(slsHeaderKeys) {
|
||||||
|
canoHeaders += "\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calc CanonicalizedResource
|
||||||
|
u, err := url.Parse(uri)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
canoResource += url.QueryEscape(u.Path)
|
||||||
|
if u.RawQuery != "" {
|
||||||
|
var keys sort.StringSlice
|
||||||
|
|
||||||
|
vals := u.Query()
|
||||||
|
for k := range vals {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(keys)
|
||||||
|
canoResource += "?"
|
||||||
|
for i, k := range keys {
|
||||||
|
if i > 0 {
|
||||||
|
canoResource += "&"
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range vals[k] {
|
||||||
|
canoResource += k + "=" + v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
signStr := method + "\n" +
|
||||||
|
contentMD5 + "\n" +
|
||||||
|
contentType + "\n" +
|
||||||
|
date + "\n" +
|
||||||
|
canoHeaders + "\n" +
|
||||||
|
canoResource
|
||||||
|
|
||||||
|
// Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret))
|
||||||
|
mac := hmac.New(sha1.New, []byte(project.AccessKeySecret))
|
||||||
|
_, err = mac.Write([]byte(signStr))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
digest = base64.StdEncoding.EncodeToString(mac.Sum(nil))
|
||||||
|
return
|
||||||
|
}
|
@ -361,7 +361,7 @@ func isParameterChar(b byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cw *ansiColorWriter) Write(p []byte) (int, error) {
|
func (cw *ansiColorWriter) Write(p []byte) (int, error) {
|
||||||
r, nw, first, last := 0, 0, 0, 0
|
var r, nw, first, last int
|
||||||
if cw.mode != DiscardNonColorEscSeq {
|
if cw.mode != DiscardNonColorEscSeq {
|
||||||
cw.state = outsideCsiCode
|
cw.state = outsideCsiCode
|
||||||
cw.resetBuffer()
|
cw.resetBuffer()
|
||||||
|
@ -41,7 +41,7 @@ var colors = []brush{
|
|||||||
newBrush("1;33"), // Warning yellow
|
newBrush("1;33"), // Warning yellow
|
||||||
newBrush("1;32"), // Notice green
|
newBrush("1;32"), // Notice green
|
||||||
newBrush("1;34"), // Informational blue
|
newBrush("1;34"), // Informational blue
|
||||||
newBrush("1;34"), // Debug blue
|
newBrush("1;44"), // Debug Background blue
|
||||||
}
|
}
|
||||||
|
|
||||||
// consoleWriter implements LoggerInterface and writes messages to terminal.
|
// consoleWriter implements LoggerInterface and writes messages to terminal.
|
||||||
|
49
logs/file.go
49
logs/file.go
@ -56,17 +56,20 @@ type fileLogWriter struct {
|
|||||||
|
|
||||||
Perm string `json:"perm"`
|
Perm string `json:"perm"`
|
||||||
|
|
||||||
|
RotatePerm string `json:"rotateperm"`
|
||||||
|
|
||||||
fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix
|
fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix
|
||||||
}
|
}
|
||||||
|
|
||||||
// newFileWriter create a FileLogWriter returning as LoggerInterface.
|
// newFileWriter create a FileLogWriter returning as LoggerInterface.
|
||||||
func newFileWriter() Logger {
|
func newFileWriter() Logger {
|
||||||
w := &fileLogWriter{
|
w := &fileLogWriter{
|
||||||
Daily: true,
|
Daily: true,
|
||||||
MaxDays: 7,
|
MaxDays: 7,
|
||||||
Rotate: true,
|
Rotate: true,
|
||||||
Level: LevelTrace,
|
RotatePerm: "0440",
|
||||||
Perm: "0660",
|
Level: LevelTrace,
|
||||||
|
Perm: "0660",
|
||||||
}
|
}
|
||||||
return w
|
return w
|
||||||
}
|
}
|
||||||
@ -170,7 +173,7 @@ func (w *fileLogWriter) initFd() error {
|
|||||||
fd := w.fileWriter
|
fd := w.fileWriter
|
||||||
fInfo, err := fd.Stat()
|
fInfo, err := fd.Stat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get stat err: %s\n", err)
|
return fmt.Errorf("get stat err: %s", err)
|
||||||
}
|
}
|
||||||
w.maxSizeCurSize = int(fInfo.Size())
|
w.maxSizeCurSize = int(fInfo.Size())
|
||||||
w.dailyOpenTime = time.Now()
|
w.dailyOpenTime = time.Now()
|
||||||
@ -193,16 +196,14 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) {
|
|||||||
y, m, d := openTime.Add(24 * time.Hour).Date()
|
y, m, d := openTime.Add(24 * time.Hour).Date()
|
||||||
nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location())
|
nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location())
|
||||||
tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100))
|
tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100))
|
||||||
select {
|
<-tm.C
|
||||||
case <-tm.C:
|
w.Lock()
|
||||||
w.Lock()
|
if w.needRotate(0, time.Now().Day()) {
|
||||||
if w.needRotate(0, time.Now().Day()) {
|
if err := w.doRotate(time.Now()); err != nil {
|
||||||
if err := w.doRotate(time.Now()); err != nil {
|
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
|
||||||
fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
w.Unlock()
|
|
||||||
}
|
}
|
||||||
|
w.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *fileLogWriter) lines() (int, error) {
|
func (w *fileLogWriter) lines() (int, error) {
|
||||||
@ -239,8 +240,12 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
|
|||||||
// Find the next available number
|
// Find the next available number
|
||||||
num := 1
|
num := 1
|
||||||
fName := ""
|
fName := ""
|
||||||
|
rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
_, err := os.Lstat(w.Filename)
|
_, err = os.Lstat(w.Filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//even if the file is not exist or other ,we should RESTART the logger
|
//even if the file is not exist or other ,we should RESTART the logger
|
||||||
goto RESTART_LOGGER
|
goto RESTART_LOGGER
|
||||||
@ -261,7 +266,7 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
|
|||||||
}
|
}
|
||||||
// return error if the last file checked still existed
|
// return error if the last file checked still existed
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename)
|
return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
// close fileWriter before rename
|
// close fileWriter before rename
|
||||||
@ -270,20 +275,24 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
|
|||||||
// Rename the file to its new found name
|
// Rename the file to its new found name
|
||||||
// even if occurs error,we MUST guarantee to restart new logger
|
// even if occurs error,we MUST guarantee to restart new logger
|
||||||
err = os.Rename(w.Filename, fName)
|
err = os.Rename(w.Filename, fName)
|
||||||
// re-start logger
|
if err != nil {
|
||||||
|
goto RESTART_LOGGER
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.Chmod(fName, os.FileMode(rotatePerm))
|
||||||
|
|
||||||
RESTART_LOGGER:
|
RESTART_LOGGER:
|
||||||
|
|
||||||
startLoggerErr := w.startLogger()
|
startLoggerErr := w.startLogger()
|
||||||
go w.deleteOldLog()
|
go w.deleteOldLog()
|
||||||
|
|
||||||
if startLoggerErr != nil {
|
if startLoggerErr != nil {
|
||||||
return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr)
|
return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Rotate: %s\n", err)
|
return fmt.Errorf("Rotate: %s", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *fileLogWriter) deleteOldLog() {
|
func (w *fileLogWriter) deleteOldLog() {
|
||||||
|
@ -162,14 +162,35 @@ func TestFileRotate_05(t *testing.T) {
|
|||||||
testFileDailyRotate(t, fn1, fn2)
|
testFileDailyRotate(t, fn1, fn2)
|
||||||
os.Remove(fn)
|
os.Remove(fn)
|
||||||
}
|
}
|
||||||
|
func TestFileRotate_06(t *testing.T) { //test file mode
|
||||||
|
log := NewLogger(10000)
|
||||||
|
log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
|
||||||
|
log.Debug("debug")
|
||||||
|
log.Info("info")
|
||||||
|
log.Notice("notice")
|
||||||
|
log.Warning("warning")
|
||||||
|
log.Error("error")
|
||||||
|
log.Alert("alert")
|
||||||
|
log.Critical("critical")
|
||||||
|
log.Emergency("emergency")
|
||||||
|
rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
|
||||||
|
s, _ := os.Lstat(rotateName)
|
||||||
|
if s.Mode() != 0440 {
|
||||||
|
os.Remove(rotateName)
|
||||||
|
os.Remove("test3.log")
|
||||||
|
t.Fatal("rotate file mode error")
|
||||||
|
}
|
||||||
|
os.Remove(rotateName)
|
||||||
|
os.Remove("test3.log")
|
||||||
|
}
|
||||||
func testFileRotate(t *testing.T, fn1, fn2 string) {
|
func testFileRotate(t *testing.T, fn1, fn2 string) {
|
||||||
fw := &fileLogWriter{
|
fw := &fileLogWriter{
|
||||||
Daily: true,
|
Daily: true,
|
||||||
MaxDays: 7,
|
MaxDays: 7,
|
||||||
Rotate: true,
|
Rotate: true,
|
||||||
Level: LevelTrace,
|
Level: LevelTrace,
|
||||||
Perm: "0660",
|
Perm: "0660",
|
||||||
|
RotatePerm: "0440",
|
||||||
}
|
}
|
||||||
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
|
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
|
||||||
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
|
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
|
||||||
@ -188,11 +209,12 @@ func testFileRotate(t *testing.T, fn1, fn2 string) {
|
|||||||
|
|
||||||
func testFileDailyRotate(t *testing.T, fn1, fn2 string) {
|
func testFileDailyRotate(t *testing.T, fn1, fn2 string) {
|
||||||
fw := &fileLogWriter{
|
fw := &fileLogWriter{
|
||||||
Daily: true,
|
Daily: true,
|
||||||
MaxDays: 7,
|
MaxDays: 7,
|
||||||
Rotate: true,
|
Rotate: true,
|
||||||
Level: LevelTrace,
|
Level: LevelTrace,
|
||||||
Perm: "0660",
|
Perm: "0660",
|
||||||
|
RotatePerm: "0440",
|
||||||
}
|
}
|
||||||
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
|
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
|
||||||
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
|
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
|
||||||
|
@ -25,11 +25,7 @@ func newJLWriter() Logger {
|
|||||||
|
|
||||||
// Init JLWriter with json config string
|
// Init JLWriter with json config string
|
||||||
func (s *JLWriter) Init(jsonconfig string) error {
|
func (s *JLWriter) Init(jsonconfig string) error {
|
||||||
err := json.Unmarshal([]byte(jsonconfig), s)
|
return json.Unmarshal([]byte(jsonconfig), s)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteMsg write message in smtp writer.
|
// WriteMsg write message in smtp writer.
|
||||||
@ -65,12 +61,10 @@ func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error {
|
|||||||
|
|
||||||
// Flush implementing method. empty.
|
// Flush implementing method. empty.
|
||||||
func (s *JLWriter) Flush() {
|
func (s *JLWriter) Flush() {
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destroy implementing method. empty.
|
// Destroy implementing method. empty.
|
||||||
func (s *JLWriter) Destroy() {
|
func (s *JLWriter) Destroy() {
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
14
logs/log.go
14
logs/log.go
@ -71,6 +71,7 @@ const (
|
|||||||
AdapterEs = "es"
|
AdapterEs = "es"
|
||||||
AdapterJianLiao = "jianliao"
|
AdapterJianLiao = "jianliao"
|
||||||
AdapterSlack = "slack"
|
AdapterSlack = "slack"
|
||||||
|
AdapterAliLS = "alils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Legacy log level constants to ensure backwards compatibility.
|
// Legacy log level constants to ensure backwards compatibility.
|
||||||
@ -274,7 +275,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error
|
|||||||
line = 0
|
line = 0
|
||||||
}
|
}
|
||||||
_, filename := path.Split(file)
|
_, filename := path.Split(file)
|
||||||
msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg
|
msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg
|
||||||
}
|
}
|
||||||
|
|
||||||
//set level info in front of filename info
|
//set level info in front of filename info
|
||||||
@ -491,9 +492,9 @@ func (bl *BeeLogger) flush() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// beeLogger references the used application logger.
|
// beeLogger references the used application logger.
|
||||||
var beeLogger *BeeLogger = NewLogger()
|
var beeLogger = NewLogger()
|
||||||
|
|
||||||
// GetLogger returns the default BeeLogger
|
// GetBeeLogger returns the default BeeLogger
|
||||||
func GetBeeLogger() *BeeLogger {
|
func GetBeeLogger() *BeeLogger {
|
||||||
return beeLogger
|
return beeLogger
|
||||||
}
|
}
|
||||||
@ -533,6 +534,7 @@ func Reset() {
|
|||||||
beeLogger.Reset()
|
beeLogger.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Async set the beelogger with Async mode and hold msglen messages
|
||||||
func Async(msgLen ...int64) *BeeLogger {
|
func Async(msgLen ...int64) *BeeLogger {
|
||||||
return beeLogger.Async(msgLen...)
|
return beeLogger.Async(msgLen...)
|
||||||
}
|
}
|
||||||
@ -560,11 +562,7 @@ func SetLogFuncCallDepth(d int) {
|
|||||||
|
|
||||||
// SetLogger sets a new logger.
|
// SetLogger sets a new logger.
|
||||||
func SetLogger(adapter string, config ...string) error {
|
func SetLogger(adapter string, config ...string) error {
|
||||||
err := beeLogger.SetLogger(adapter, config...)
|
return beeLogger.SetLogger(adapter, config...)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Emergency logs a message at emergency level.
|
// Emergency logs a message at emergency level.
|
||||||
|
@ -139,6 +139,11 @@ var (
|
|||||||
reset = string([]byte{27, 91, 48, 109})
|
reset = string([]byte{27, 91, 48, 109})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ColorByStatus return color by http code
|
||||||
|
// 2xx return Green
|
||||||
|
// 3xx return White
|
||||||
|
// 4xx return Yellow
|
||||||
|
// 5xx return Red
|
||||||
func ColorByStatus(cond bool, code int) string {
|
func ColorByStatus(cond bool, code int) string {
|
||||||
switch {
|
switch {
|
||||||
case code >= 200 && code < 300:
|
case code >= 200 && code < 300:
|
||||||
@ -152,6 +157,14 @@ func ColorByStatus(cond bool, code int) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ColorByMethod return color by http code
|
||||||
|
// GET return Blue
|
||||||
|
// POST return Cyan
|
||||||
|
// PUT return Yellow
|
||||||
|
// DELETE return Red
|
||||||
|
// PATCH return Green
|
||||||
|
// HEAD return Magenta
|
||||||
|
// OPTIONS return WHITE
|
||||||
func ColorByMethod(cond bool, method string) string {
|
func ColorByMethod(cond bool, method string) string {
|
||||||
switch method {
|
switch method {
|
||||||
case "GET":
|
case "GET":
|
||||||
@ -173,10 +186,10 @@ func ColorByMethod(cond bool, method string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Guard Mutex to guarantee atomicity of W32Debug(string) function
|
// Guard Mutex to guarantee atomic of W32Debug(string) function
|
||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
|
|
||||||
// Helper method to output colored logs in Windows terminals
|
// W32Debug Helper method to output colored logs in Windows terminals
|
||||||
func W32Debug(msg string) {
|
func W32Debug(msg string) {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
|
@ -21,11 +21,7 @@ func newSLACKWriter() Logger {
|
|||||||
|
|
||||||
// Init SLACKWriter with json config string
|
// Init SLACKWriter with json config string
|
||||||
func (s *SLACKWriter) Init(jsonconfig string) error {
|
func (s *SLACKWriter) Init(jsonconfig string) error {
|
||||||
err := json.Unmarshal([]byte(jsonconfig), s)
|
return json.Unmarshal([]byte(jsonconfig), s)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteMsg write message in smtp writer.
|
// WriteMsg write message in smtp writer.
|
||||||
@ -53,12 +49,10 @@ func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error {
|
|||||||
|
|
||||||
// Flush implementing method. empty.
|
// Flush implementing method. empty.
|
||||||
func (s *SLACKWriter) Flush() {
|
func (s *SLACKWriter) Flush() {
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destroy implementing method. empty.
|
// Destroy implementing method. empty.
|
||||||
func (s *SLACKWriter) Destroy() {
|
func (s *SLACKWriter) Destroy() {
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
17
logs/smtp.go
17
logs/smtp.go
@ -52,11 +52,7 @@ func newSMTPWriter() Logger {
|
|||||||
// "level":LevelError
|
// "level":LevelError
|
||||||
// }
|
// }
|
||||||
func (s *SMTPWriter) Init(jsonconfig string) error {
|
func (s *SMTPWriter) Init(jsonconfig string) error {
|
||||||
err := json.Unmarshal([]byte(jsonconfig), s)
|
return json.Unmarshal([]byte(jsonconfig), s)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth {
|
func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth {
|
||||||
@ -106,7 +102,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = w.Write([]byte(msgContent))
|
_, err = w.Write(msgContent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -116,12 +112,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = client.Quit()
|
return client.Quit()
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteMsg write message in smtp writer.
|
// WriteMsg write message in smtp writer.
|
||||||
@ -147,12 +138,10 @@ func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error {
|
|||||||
|
|
||||||
// Flush implementing method. empty.
|
// Flush implementing method. empty.
|
||||||
func (s *SMTPWriter) Flush() {
|
func (s *SMTPWriter) Flush() {
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destroy implementing method. empty.
|
// Destroy implementing method. empty.
|
||||||
func (s *SMTPWriter) Destroy() {
|
func (s *SMTPWriter) Destroy() {
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
392
migration/ddl.go
392
migration/ddl.go
@ -14,40 +14,382 @@
|
|||||||
|
|
||||||
package migration
|
package migration
|
||||||
|
|
||||||
// Table store the tablename and Column
|
import (
|
||||||
type Table struct {
|
"fmt"
|
||||||
TableName string
|
|
||||||
Columns []*Column
|
"github.com/astaxie/beego"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Index struct defines the structure of Index Columns
|
||||||
|
type Index struct {
|
||||||
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create return the create sql
|
// Unique struct defines a single unique key combination
|
||||||
func (t *Table) Create() string {
|
type Unique struct {
|
||||||
return ""
|
Definition string
|
||||||
|
Columns []*Column
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop return the drop sql
|
//Column struct defines a single column of a table
|
||||||
func (t *Table) Drop() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Column define the columns name type and Default
|
|
||||||
type Column struct {
|
type Column struct {
|
||||||
Name string
|
Name string
|
||||||
Type string
|
Inc string
|
||||||
Default interface{}
|
Null string
|
||||||
|
Default string
|
||||||
|
Unsign string
|
||||||
|
DataType string
|
||||||
|
remove bool
|
||||||
|
Modify bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create return create sql with the provided tbname and columns
|
// Foreign struct defines a single foreign relationship
|
||||||
func Create(tbname string, columns ...Column) string {
|
type Foreign struct {
|
||||||
return ""
|
ForeignTable string
|
||||||
|
ForeignColumn string
|
||||||
|
OnDelete string
|
||||||
|
OnUpdate string
|
||||||
|
Column
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop return the drop sql with the provided tbname and columns
|
// RenameColumn struct allows renaming of columns
|
||||||
func Drop(tbname string, columns ...Column) string {
|
type RenameColumn struct {
|
||||||
return ""
|
OldName string
|
||||||
|
OldNull string
|
||||||
|
OldDefault string
|
||||||
|
OldUnsign string
|
||||||
|
OldDataType string
|
||||||
|
NewName string
|
||||||
|
Column
|
||||||
}
|
}
|
||||||
|
|
||||||
// TableDDL is still in think
|
// CreateTable creates the table on system
|
||||||
func TableDDL(tbname string, columns ...Column) string {
|
func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) {
|
||||||
return ""
|
m.TableName = tablename
|
||||||
|
m.Engine = engine
|
||||||
|
m.Charset = charset
|
||||||
|
m.ModifyType = "create"
|
||||||
|
}
|
||||||
|
|
||||||
|
// AlterTable set the ModifyType to alter
|
||||||
|
func (m *Migration) AlterTable(tablename string) {
|
||||||
|
m.TableName = tablename
|
||||||
|
m.ModifyType = "alter"
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCol creates a new standard column and attaches it to m struct
|
||||||
|
func (m *Migration) NewCol(name string) *Column {
|
||||||
|
col := &Column{Name: name}
|
||||||
|
m.AddColumns(col)
|
||||||
|
return col
|
||||||
|
}
|
||||||
|
|
||||||
|
//PriCol creates a new primary column and attaches it to m struct
|
||||||
|
func (m *Migration) PriCol(name string) *Column {
|
||||||
|
col := &Column{Name: name}
|
||||||
|
m.AddColumns(col)
|
||||||
|
m.AddPrimary(col)
|
||||||
|
return col
|
||||||
|
}
|
||||||
|
|
||||||
|
//UniCol creates / appends columns to specified unique key and attaches it to m struct
|
||||||
|
func (m *Migration) UniCol(uni, name string) *Column {
|
||||||
|
col := &Column{Name: name}
|
||||||
|
m.AddColumns(col)
|
||||||
|
|
||||||
|
uniqueOriginal := &Unique{}
|
||||||
|
|
||||||
|
for _, unique := range m.Uniques {
|
||||||
|
if unique.Definition == uni {
|
||||||
|
unique.AddColumnsToUnique(col)
|
||||||
|
uniqueOriginal = unique
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if uniqueOriginal.Definition == "" {
|
||||||
|
unique := &Unique{Definition: uni}
|
||||||
|
unique.AddColumnsToUnique(col)
|
||||||
|
m.AddUnique(unique)
|
||||||
|
}
|
||||||
|
|
||||||
|
return col
|
||||||
|
}
|
||||||
|
|
||||||
|
//ForeignCol creates a new foreign column and returns the instance of column
|
||||||
|
func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) {
|
||||||
|
|
||||||
|
foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable}
|
||||||
|
foreign.Name = colname
|
||||||
|
m.AddForeign(foreign)
|
||||||
|
return foreign
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetOnDelete sets the on delete of foreign
|
||||||
|
func (foreign *Foreign) SetOnDelete(del string) *Foreign {
|
||||||
|
foreign.OnDelete = "ON DELETE" + del
|
||||||
|
return foreign
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetOnUpdate sets the on update of foreign
|
||||||
|
func (foreign *Foreign) SetOnUpdate(update string) *Foreign {
|
||||||
|
foreign.OnUpdate = "ON UPDATE" + update
|
||||||
|
return foreign
|
||||||
|
}
|
||||||
|
|
||||||
|
//Remove marks the columns to be removed.
|
||||||
|
//it allows reverse m to create the column.
|
||||||
|
func (c *Column) Remove() {
|
||||||
|
c.remove = true
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetAuto enables auto_increment of column (can be used once)
|
||||||
|
func (c *Column) SetAuto(inc bool) *Column {
|
||||||
|
if inc {
|
||||||
|
c.Inc = "auto_increment"
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetNullable sets the column to be null
|
||||||
|
func (c *Column) SetNullable(null bool) *Column {
|
||||||
|
if null {
|
||||||
|
c.Null = ""
|
||||||
|
|
||||||
|
} else {
|
||||||
|
c.Null = "NOT NULL"
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetDefault sets the default value, prepend with "DEFAULT "
|
||||||
|
func (c *Column) SetDefault(def string) *Column {
|
||||||
|
c.Default = "DEFAULT " + def
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetUnsigned sets the column to be unsigned int
|
||||||
|
func (c *Column) SetUnsigned(unsign bool) *Column {
|
||||||
|
if unsign {
|
||||||
|
c.Unsign = "UNSIGNED"
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetDataType sets the dataType of the column
|
||||||
|
func (c *Column) SetDataType(dataType string) *Column {
|
||||||
|
c.DataType = dataType
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetOldNullable allows reverting to previous nullable on reverse ms
|
||||||
|
func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn {
|
||||||
|
if null {
|
||||||
|
c.OldNull = ""
|
||||||
|
|
||||||
|
} else {
|
||||||
|
c.OldNull = "NOT NULL"
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetOldDefault allows reverting to previous default on reverse ms
|
||||||
|
func (c *RenameColumn) SetOldDefault(def string) *RenameColumn {
|
||||||
|
c.OldDefault = def
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetOldUnsigned allows reverting to previous unsgined on reverse ms
|
||||||
|
func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn {
|
||||||
|
if unsign {
|
||||||
|
c.OldUnsign = "UNSIGNED"
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetOldDataType allows reverting to previous datatype on reverse ms
|
||||||
|
func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn {
|
||||||
|
c.OldDataType = dataType
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m)
|
||||||
|
func (c *Column) SetPrimary(m *Migration) *Column {
|
||||||
|
m.Primary = append(m.Primary, c)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//AddColumnsToUnique adds the columns to Unique Struct
|
||||||
|
func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique {
|
||||||
|
|
||||||
|
unique.Columns = append(unique.Columns, columns...)
|
||||||
|
|
||||||
|
return unique
|
||||||
|
}
|
||||||
|
|
||||||
|
//AddColumns adds columns to m struct
|
||||||
|
func (m *Migration) AddColumns(columns ...*Column) *Migration {
|
||||||
|
|
||||||
|
m.Columns = append(m.Columns, columns...)
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
//AddPrimary adds the column to primary in m struct
|
||||||
|
func (m *Migration) AddPrimary(primary *Column) *Migration {
|
||||||
|
m.Primary = append(m.Primary, primary)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
//AddUnique adds the column to unique in m struct
|
||||||
|
func (m *Migration) AddUnique(unique *Unique) *Migration {
|
||||||
|
m.Uniques = append(m.Uniques, unique)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
//AddForeign adds the column to foreign in m struct
|
||||||
|
func (m *Migration) AddForeign(foreign *Foreign) *Migration {
|
||||||
|
m.Foreigns = append(m.Foreigns, foreign)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
//AddIndex adds the column to index in m struct
|
||||||
|
func (m *Migration) AddIndex(index *Index) *Migration {
|
||||||
|
m.Indexes = append(m.Indexes, index)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
//RenameColumn allows renaming of columns
|
||||||
|
func (m *Migration) RenameColumn(from, to string) *RenameColumn {
|
||||||
|
rename := &RenameColumn{OldName: from, NewName: to}
|
||||||
|
m.Renames = append(m.Renames, rename)
|
||||||
|
return rename
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetSQL returns the generated sql depending on ModifyType
|
||||||
|
func (m *Migration) GetSQL() (sql string) {
|
||||||
|
sql = ""
|
||||||
|
switch m.ModifyType {
|
||||||
|
case "create":
|
||||||
|
{
|
||||||
|
sql += fmt.Sprintf("CREATE TABLE `%s` (", m.TableName)
|
||||||
|
for index, column := range m.Columns {
|
||||||
|
sql += fmt.Sprintf("\n `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
|
||||||
|
if len(m.Columns) > index+1 {
|
||||||
|
sql += ","
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.Primary) > 0 {
|
||||||
|
sql += fmt.Sprintf(",\n PRIMARY KEY( ")
|
||||||
|
}
|
||||||
|
for index, column := range m.Primary {
|
||||||
|
sql += fmt.Sprintf(" `%s`", column.Name)
|
||||||
|
if len(m.Primary) > index+1 {
|
||||||
|
sql += ","
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
if len(m.Primary) > 0 {
|
||||||
|
sql += fmt.Sprintf(")")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, unique := range m.Uniques {
|
||||||
|
sql += fmt.Sprintf(",\n UNIQUE KEY `%s`( ", unique.Definition)
|
||||||
|
for index, column := range unique.Columns {
|
||||||
|
sql += fmt.Sprintf(" `%s`", column.Name)
|
||||||
|
if len(unique.Columns) > index+1 {
|
||||||
|
sql += ","
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sql += fmt.Sprintf(")")
|
||||||
|
}
|
||||||
|
for _, foreign := range m.Foreigns {
|
||||||
|
sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
|
||||||
|
sql += fmt.Sprintf(",\n KEY `%s_%s_foreign`(`%s`),", m.TableName, foreign.Column.Name, foreign.Column.Name)
|
||||||
|
sql += fmt.Sprintf("\n CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate)
|
||||||
|
|
||||||
|
}
|
||||||
|
sql += fmt.Sprintf(")ENGINE=%s DEFAULT CHARSET=%s;", m.Engine, m.Charset)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case "alter":
|
||||||
|
{
|
||||||
|
sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName)
|
||||||
|
for index, column := range m.Columns {
|
||||||
|
if !column.remove {
|
||||||
|
beego.BeeLogger.Info("col")
|
||||||
|
sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
|
||||||
|
} else {
|
||||||
|
sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.Columns) > index {
|
||||||
|
sql += ","
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for index, column := range m.Renames {
|
||||||
|
sql += fmt.Sprintf("CHANGE COLUMN `%s` `%s` %s %s %s %s %s", column.OldName, column.NewName, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
|
||||||
|
if len(m.Renames) > index+1 {
|
||||||
|
sql += ","
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for index, foreign := range m.Foreigns {
|
||||||
|
sql += fmt.Sprintf("ADD `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
|
||||||
|
sql += fmt.Sprintf(",\n ADD KEY `%s_%s_foreign`(`%s`)", m.TableName, foreign.Column.Name, foreign.Column.Name)
|
||||||
|
sql += fmt.Sprintf(",\n ADD CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate)
|
||||||
|
if len(m.Foreigns) > index+1 {
|
||||||
|
sql += ","
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sql += ";"
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case "reverse":
|
||||||
|
{
|
||||||
|
|
||||||
|
sql += fmt.Sprintf("ALTER TABLE `%s`", m.TableName)
|
||||||
|
for index, column := range m.Columns {
|
||||||
|
if column.remove {
|
||||||
|
sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
|
||||||
|
} else {
|
||||||
|
sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
|
||||||
|
}
|
||||||
|
if len(m.Columns) > index {
|
||||||
|
sql += ","
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.Primary) > 0 {
|
||||||
|
sql += fmt.Sprintf("\n DROP PRIMARY KEY,")
|
||||||
|
}
|
||||||
|
|
||||||
|
for index, unique := range m.Uniques {
|
||||||
|
sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition)
|
||||||
|
if len(m.Uniques) > index {
|
||||||
|
sql += ","
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
for index, column := range m.Renames {
|
||||||
|
sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault)
|
||||||
|
if len(m.Renames) > index {
|
||||||
|
sql += ","
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, foreign := range m.Foreigns {
|
||||||
|
sql += fmt.Sprintf("\n DROP KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name)
|
||||||
|
sql += fmt.Sprintf(",\n DROP FOREIGN KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name)
|
||||||
|
sql += fmt.Sprintf(",\n DROP COLUMN `%s`", foreign.Name)
|
||||||
|
}
|
||||||
|
sql += ";"
|
||||||
|
}
|
||||||
|
case "delete":
|
||||||
|
{
|
||||||
|
sql += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;", m.TableName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
32
migration/doc.go
Normal file
32
migration/doc.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
// Package migration enables you to generate migrations back and forth. It generates both migrations.
|
||||||
|
//
|
||||||
|
// //Creates a table
|
||||||
|
// m.CreateTable("tablename","InnoDB","utf8");
|
||||||
|
//
|
||||||
|
// //Alter a table
|
||||||
|
// m.AlterTable("tablename")
|
||||||
|
//
|
||||||
|
// Standard Column Methods
|
||||||
|
// * SetDataType
|
||||||
|
// * SetNullable
|
||||||
|
// * SetDefault
|
||||||
|
// * SetUnsigned (use only on integer types unless produces error)
|
||||||
|
//
|
||||||
|
// //Sets a primary column, multiple calls allowed, standard column methods available
|
||||||
|
// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true)
|
||||||
|
//
|
||||||
|
// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index
|
||||||
|
// m.UniCol("index","column")
|
||||||
|
//
|
||||||
|
// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove
|
||||||
|
// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false)
|
||||||
|
// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false)
|
||||||
|
//
|
||||||
|
// //Rename Columns , only use with Alter table, doesn't works with Create, prefix standard column methods with "Old" to
|
||||||
|
// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)")
|
||||||
|
// m.RenameColumn("from","to")...
|
||||||
|
//
|
||||||
|
// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately.
|
||||||
|
// //Supports standard column methods, automatic reverse.
|
||||||
|
// m.ForeignCol("local_col","foreign_col","foreign_table")
|
||||||
|
package migration
|
@ -52,6 +52,26 @@ type Migrationer interface {
|
|||||||
GetCreated() int64
|
GetCreated() int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Migration defines the migrations by either SQL or DDL
|
||||||
|
type Migration struct {
|
||||||
|
sqls []string
|
||||||
|
Created string
|
||||||
|
TableName string
|
||||||
|
Engine string
|
||||||
|
Charset string
|
||||||
|
ModifyType string
|
||||||
|
Columns []*Column
|
||||||
|
Indexes []*Index
|
||||||
|
Primary []*Column
|
||||||
|
Uniques []*Unique
|
||||||
|
Foreigns []*Foreign
|
||||||
|
Renames []*RenameColumn
|
||||||
|
RemoveColumns []*Column
|
||||||
|
RemoveIndexes []*Index
|
||||||
|
RemoveUniques []*Unique
|
||||||
|
RemoveForeigns []*Foreign
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
migrationMap map[string]Migrationer
|
migrationMap map[string]Migrationer
|
||||||
)
|
)
|
||||||
@ -60,20 +80,34 @@ func init() {
|
|||||||
migrationMap = make(map[string]Migrationer)
|
migrationMap = make(map[string]Migrationer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migration the basic type which will implement the basic type
|
|
||||||
type Migration struct {
|
|
||||||
sqls []string
|
|
||||||
Created string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Up implement in the Inheritance struct for upgrade
|
// Up implement in the Inheritance struct for upgrade
|
||||||
func (m *Migration) Up() {
|
func (m *Migration) Up() {
|
||||||
|
|
||||||
|
switch m.ModifyType {
|
||||||
|
case "reverse":
|
||||||
|
m.ModifyType = "alter"
|
||||||
|
case "delete":
|
||||||
|
m.ModifyType = "create"
|
||||||
|
}
|
||||||
|
m.sqls = append(m.sqls, m.GetSQL())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Down implement in the Inheritance struct for down
|
// Down implement in the Inheritance struct for down
|
||||||
func (m *Migration) Down() {
|
func (m *Migration) Down() {
|
||||||
|
|
||||||
|
switch m.ModifyType {
|
||||||
|
case "alter":
|
||||||
|
m.ModifyType = "reverse"
|
||||||
|
case "create":
|
||||||
|
m.ModifyType = "delete"
|
||||||
|
}
|
||||||
|
m.sqls = append(m.sqls, m.GetSQL())
|
||||||
|
}
|
||||||
|
|
||||||
|
//Migrate adds the SQL to the execution list
|
||||||
|
func (m *Migration) Migrate(migrationType string) {
|
||||||
|
m.ModifyType = migrationType
|
||||||
|
m.sqls = append(m.sqls, m.GetSQL())
|
||||||
}
|
}
|
||||||
|
|
||||||
// SQL add sql want to execute
|
// SQL add sql want to execute
|
||||||
|
11
namespace.go
11
namespace.go
@ -267,13 +267,12 @@ func addPrefix(t *Tree, prefix string) {
|
|||||||
addPrefix(t.wildcard, prefix)
|
addPrefix(t.wildcard, prefix)
|
||||||
}
|
}
|
||||||
for _, l := range t.leaves {
|
for _, l := range t.leaves {
|
||||||
if c, ok := l.runObject.(*controllerInfo); ok {
|
if c, ok := l.runObject.(*ControllerInfo); ok {
|
||||||
if !strings.HasPrefix(c.pattern, prefix) {
|
if !strings.HasPrefix(c.pattern, prefix) {
|
||||||
c.pattern = prefix + c.pattern
|
c.pattern = prefix + c.pattern
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NSCond is Namespace Condition
|
// NSCond is Namespace Condition
|
||||||
@ -284,16 +283,16 @@ func NSCond(cond namespaceCond) LinkNamespace {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NSBefore Namespace BeforeRouter filter
|
// NSBefore Namespace BeforeRouter filter
|
||||||
func NSBefore(filiterList ...FilterFunc) LinkNamespace {
|
func NSBefore(filterList ...FilterFunc) LinkNamespace {
|
||||||
return func(ns *Namespace) {
|
return func(ns *Namespace) {
|
||||||
ns.Filter("before", filiterList...)
|
ns.Filter("before", filterList...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NSAfter add Namespace FinishRouter filter
|
// NSAfter add Namespace FinishRouter filter
|
||||||
func NSAfter(filiterList ...FilterFunc) LinkNamespace {
|
func NSAfter(filterList ...FilterFunc) LinkNamespace {
|
||||||
return func(ns *Namespace) {
|
return func(ns *Namespace) {
|
||||||
ns.Filter("after", filiterList...)
|
ns.Filter("after", filterList...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,10 +139,7 @@ func TestNamespaceCond(t *testing.T) {
|
|||||||
|
|
||||||
ns := NewNamespace("/v2")
|
ns := NewNamespace("/v2")
|
||||||
ns.Cond(func(ctx *context.Context) bool {
|
ns.Cond(func(ctx *context.Context) bool {
|
||||||
if ctx.Input.Domain() == "beego.me" {
|
return ctx.Input.Domain() == "beego.me"
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}).
|
}).
|
||||||
AutoRouter(&TestController{})
|
AutoRouter(&TestController{})
|
||||||
AddNamespace(ns)
|
AddNamespace(ns)
|
||||||
|
@ -150,7 +150,7 @@ func (d *commandSyncDb) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, fi := range mi.fields.fieldsDB {
|
for _, fi := range mi.fields.fieldsDB {
|
||||||
if _, ok := columns[fi.column]; ok == false {
|
if _, ok := columns[fi.column]; !ok {
|
||||||
fields = append(fields, fi)
|
fields = append(fields, fi)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -175,7 +175,7 @@ func (d *commandSyncDb) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, idx := range indexes[mi.table] {
|
for _, idx := range indexes[mi.table] {
|
||||||
if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false {
|
if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) {
|
||||||
if !d.noInfo {
|
if !d.noInfo {
|
||||||
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
|
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
|
||||||
}
|
}
|
||||||
|
@ -89,7 +89,7 @@ checkColumn:
|
|||||||
col = T["float64"]
|
col = T["float64"]
|
||||||
case TypeDecimalField:
|
case TypeDecimalField:
|
||||||
s := T["float64-decimal"]
|
s := T["float64-decimal"]
|
||||||
if strings.Index(s, "%d") == -1 {
|
if !strings.Contains(s, "%d") {
|
||||||
col = s
|
col = s
|
||||||
} else {
|
} else {
|
||||||
col = fmt.Sprintf(s, fi.digits, fi.decimals)
|
col = fmt.Sprintf(s, fi.digits, fi.decimals)
|
||||||
@ -120,7 +120,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
|
|||||||
Q := al.DbBaser.TableQuote()
|
Q := al.DbBaser.TableQuote()
|
||||||
typ := getColumnTyp(al, fi)
|
typ := getColumnTyp(al, fi)
|
||||||
|
|
||||||
if fi.null == false {
|
if !fi.null {
|
||||||
typ += " " + "NOT NULL"
|
typ += " " + "NOT NULL"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -172,7 +172,7 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
|
|||||||
} else {
|
} else {
|
||||||
column += col
|
column += col
|
||||||
|
|
||||||
if fi.null == false {
|
if !fi.null {
|
||||||
column += " " + "NOT NULL"
|
column += " " + "NOT NULL"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -192,7 +192,7 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Index(column, "%COL%") != -1 {
|
if strings.Contains(column, "%COL%") {
|
||||||
column = strings.Replace(column, "%COL%", fi.column, -1)
|
column = strings.Replace(column, "%COL%", fi.column, -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
45
orm/db.go
45
orm/db.go
@ -48,7 +48,7 @@ var (
|
|||||||
"lte": true,
|
"lte": true,
|
||||||
"eq": true,
|
"eq": true,
|
||||||
"nq": true,
|
"nq": true,
|
||||||
"ne": true,
|
"ne": true,
|
||||||
"startswith": true,
|
"startswith": true,
|
||||||
"endswith": true,
|
"endswith": true,
|
||||||
"istartswith": true,
|
"istartswith": true,
|
||||||
@ -87,7 +87,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
|
|||||||
} else {
|
} else {
|
||||||
panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName))
|
panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName))
|
||||||
}
|
}
|
||||||
if fi.dbcol == false || fi.auto && skipAuto {
|
if !fi.dbcol || fi.auto && skipAuto {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
|
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
|
||||||
@ -224,7 +224,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
|
|||||||
value = nil
|
value = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if fi.null == false && value == nil {
|
if !fi.null && value == nil {
|
||||||
return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
|
return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -271,7 +271,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
|
|||||||
dbcols := make([]string, 0, len(mi.fields.dbcols))
|
dbcols := make([]string, 0, len(mi.fields.dbcols))
|
||||||
marks := make([]string, 0, len(mi.fields.dbcols))
|
marks := make([]string, 0, len(mi.fields.dbcols))
|
||||||
for _, fi := range mi.fields.fieldsDB {
|
for _, fi := range mi.fields.fieldsDB {
|
||||||
if fi.auto == false {
|
if !fi.auto {
|
||||||
dbcols = append(dbcols, fi.column)
|
dbcols = append(dbcols, fi.column)
|
||||||
marks = append(marks, "?")
|
marks = append(marks, "?")
|
||||||
}
|
}
|
||||||
@ -326,7 +326,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
|
|||||||
} else {
|
} else {
|
||||||
// default use pk value as where condtion.
|
// default use pk value as where condtion.
|
||||||
pkColumn, pkValue, ok := getExistPk(mi, ind)
|
pkColumn, pkValue, ok := getExistPk(mi, ind)
|
||||||
if ok == false {
|
if !ok {
|
||||||
return ErrMissPK
|
return ErrMissPK
|
||||||
}
|
}
|
||||||
whereCols = []string{pkColumn}
|
whereCols = []string{pkColumn}
|
||||||
@ -507,10 +507,9 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
case DRPostgres:
|
case DRPostgres:
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
|
return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
|
||||||
} else {
|
|
||||||
args0 = strings.ToLower(args[0])
|
|
||||||
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
|
|
||||||
}
|
}
|
||||||
|
args0 = strings.ToLower(args[0])
|
||||||
|
iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
|
return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
|
||||||
}
|
}
|
||||||
@ -592,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
row := q.QueryRow(query, values...)
|
row := q.QueryRow(query, values...)
|
||||||
var id int64
|
var id int64
|
||||||
err = row.Scan(&id)
|
err = row.Scan(&id)
|
||||||
if err.Error() == `pq: syntax error at or near "ON"` {
|
if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
|
||||||
err = fmt.Errorf("postgres version must 9.5 or higher")
|
err = fmt.Errorf("postgres version must 9.5 or higher")
|
||||||
}
|
}
|
||||||
return id, err
|
return id, err
|
||||||
@ -601,7 +600,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
|
|||||||
// execute update sql dbQuerier with given struct reflect.Value.
|
// execute update sql dbQuerier with given struct reflect.Value.
|
||||||
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
|
||||||
pkName, pkValue, ok := getExistPk(mi, ind)
|
pkName, pkValue, ok := getExistPk(mi, ind)
|
||||||
if ok == false {
|
if !ok {
|
||||||
return 0, ErrMissPK
|
return 0, ErrMissPK
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -654,7 +653,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
|
|||||||
} else {
|
} else {
|
||||||
// default use pk value as where condtion.
|
// default use pk value as where condtion.
|
||||||
pkColumn, pkValue, ok := getExistPk(mi, ind)
|
pkColumn, pkValue, ok := getExistPk(mi, ind)
|
||||||
if ok == false {
|
if !ok {
|
||||||
return 0, ErrMissPK
|
return 0, ErrMissPK
|
||||||
}
|
}
|
||||||
whereCols = []string{pkColumn}
|
whereCols = []string{pkColumn}
|
||||||
@ -699,7 +698,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
columns := make([]string, 0, len(params))
|
columns := make([]string, 0, len(params))
|
||||||
values := make([]interface{}, 0, len(params))
|
values := make([]interface{}, 0, len(params))
|
||||||
for col, val := range params {
|
for col, val := range params {
|
||||||
if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false {
|
if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol {
|
||||||
panic(fmt.Errorf("wrong field/column name `%s`", col))
|
panic(fmt.Errorf("wrong field/column name `%s`", col))
|
||||||
} else {
|
} else {
|
||||||
columns = append(columns, fi.column)
|
columns = append(columns, fi.column)
|
||||||
@ -834,7 +833,11 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
|
|||||||
if err := rs.Scan(&ref); err != nil {
|
if err := rs.Scan(&ref); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
args = append(args, reflect.ValueOf(ref).Interface())
|
pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
args = append(args, pkValue)
|
||||||
cnt++
|
cnt++
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -929,7 +932,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
if hasRel {
|
if hasRel {
|
||||||
for _, fi := range mi.fields.fieldsDB {
|
for _, fi := range mi.fields.fieldsDB {
|
||||||
if fi.fieldType&IsRelField > 0 {
|
if fi.fieldType&IsRelField > 0 {
|
||||||
if maps[fi.column] == false {
|
if !maps[fi.column] {
|
||||||
tCols = append(tCols, fi.column)
|
tCols = append(tCols, fi.column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -987,7 +990,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
|
|
||||||
var cnt int64
|
var cnt int64
|
||||||
for rs.Next() {
|
for rs.Next() {
|
||||||
if one && cnt == 0 || one == false {
|
if one && cnt == 0 || !one {
|
||||||
if err := rs.Scan(refs...); err != nil {
|
if err := rs.Scan(refs...); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -1067,7 +1070,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
|
|||||||
cnt++
|
cnt++
|
||||||
}
|
}
|
||||||
|
|
||||||
if one == false {
|
if !one {
|
||||||
if cnt > 0 {
|
if cnt > 0 {
|
||||||
ind.Set(slice)
|
ind.Set(slice)
|
||||||
} else {
|
} else {
|
||||||
@ -1110,7 +1113,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
|
|||||||
|
|
||||||
// generate sql with replacing operator string placeholders and replaced values.
|
// generate sql with replacing operator string placeholders and replaced values.
|
||||||
func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
|
func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
|
||||||
sql := ""
|
var sql string
|
||||||
params := getFlatParams(fi, args, tz)
|
params := getFlatParams(fi, args, tz)
|
||||||
|
|
||||||
if len(params) == 0 {
|
if len(params) == 0 {
|
||||||
@ -1357,7 +1360,7 @@ end:
|
|||||||
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
|
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
|
||||||
|
|
||||||
fieldType := fi.fieldType
|
fieldType := fi.fieldType
|
||||||
isNative := fi.isFielder == false
|
isNative := !fi.isFielder
|
||||||
|
|
||||||
setValue:
|
setValue:
|
||||||
switch {
|
switch {
|
||||||
@ -1533,7 +1536,7 @@ setValue:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isNative == false {
|
if !isNative {
|
||||||
fd := field.Addr().Interface().(Fielder)
|
fd := field.Addr().Interface().(Fielder)
|
||||||
err := fd.SetRaw(value)
|
err := fd.SetRaw(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1594,7 +1597,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
|
|||||||
infos = make([]*fieldInfo, 0, len(exprs))
|
infos = make([]*fieldInfo, 0, len(exprs))
|
||||||
for _, ex := range exprs {
|
for _, ex := range exprs {
|
||||||
index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
|
index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
|
||||||
if suc == false {
|
if !suc {
|
||||||
panic(fmt.Errorf("unknown field/column name `%s`", ex))
|
panic(fmt.Errorf("unknown field/column name `%s`", ex))
|
||||||
}
|
}
|
||||||
cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q))
|
cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q))
|
||||||
@ -1733,7 +1736,7 @@ func (d *dbBase) TableQuote() string {
|
|||||||
return "`"
|
return "`"
|
||||||
}
|
}
|
||||||
|
|
||||||
// replace value placeholer in parametered sql string.
|
// replace value placeholder in parametered sql string.
|
||||||
func (d *dbBase) ReplaceMarks(query *string) {
|
func (d *dbBase) ReplaceMarks(query *string) {
|
||||||
// default use `?` as mark, do nothing
|
// default use `?` as mark, do nothing
|
||||||
}
|
}
|
||||||
|
@ -60,6 +60,8 @@ var (
|
|||||||
"sqlite3": DRSqlite,
|
"sqlite3": DRSqlite,
|
||||||
"tidb": DRTiDB,
|
"tidb": DRTiDB,
|
||||||
"oracle": DROracle,
|
"oracle": DROracle,
|
||||||
|
"oci8": DROracle, // github.com/mattn/go-oci8
|
||||||
|
"ora": DROracle, //https://github.com/rana/ora
|
||||||
}
|
}
|
||||||
dbBasers = map[DriverType]dbBaser{
|
dbBasers = map[DriverType]dbBaser{
|
||||||
DRMySQL: newdbBaseMysql(),
|
DRMySQL: newdbBaseMysql(),
|
||||||
@ -186,7 +188,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
|
|||||||
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
|
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if dataBaseCache.add(aliasName, al) == false {
|
if !dataBaseCache.add(aliasName, al) {
|
||||||
return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
|
return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,11 +246,11 @@ end:
|
|||||||
|
|
||||||
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
|
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
|
||||||
func RegisterDriver(driverName string, typ DriverType) error {
|
func RegisterDriver(driverName string, typ DriverType) error {
|
||||||
if t, ok := drivers[driverName]; ok == false {
|
if t, ok := drivers[driverName]; !ok {
|
||||||
drivers[driverName] = typ
|
drivers[driverName] = typ
|
||||||
} else {
|
} else {
|
||||||
if t != typ {
|
if t != typ {
|
||||||
return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
|
return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -259,7 +261,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
|
|||||||
if al, ok := dataBaseCache.get(aliasName); ok {
|
if al, ok := dataBaseCache.get(aliasName); ok {
|
||||||
al.TZ = tz
|
al.TZ = tz
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName)
|
return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -294,5 +296,5 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
|
|||||||
if ok {
|
if ok {
|
||||||
return al.DB, nil
|
return al.DB, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
|
return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
|
||||||
}
|
}
|
||||||
|
@ -103,8 +103,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
|
|||||||
// If no will insert
|
// If no will insert
|
||||||
// Add "`" for mysql sql building
|
// Add "`" for mysql sql building
|
||||||
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
|
||||||
|
var iouStr string
|
||||||
iouStr := ""
|
|
||||||
argsMap := map[string]string{}
|
argsMap := map[string]string{}
|
||||||
|
|
||||||
iouStr = "ON DUPLICATE KEY UPDATE"
|
iouStr = "ON DUPLICATE KEY UPDATE"
|
||||||
|
@ -94,3 +94,43 @@ func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool
|
|||||||
row.Scan(&cnt)
|
row.Scan(&cnt)
|
||||||
return cnt > 0
|
return cnt > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execute insert sql with given struct and given values.
|
||||||
|
// insert the given values, not the field values in struct.
|
||||||
|
func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
|
||||||
|
Q := d.ins.TableQuote()
|
||||||
|
|
||||||
|
marks := make([]string, len(names))
|
||||||
|
for i := range marks {
|
||||||
|
marks[i] = ":" + names[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
sep := fmt.Sprintf("%s, %s", Q, Q)
|
||||||
|
qmarks := strings.Join(marks, ", ")
|
||||||
|
columns := strings.Join(names, sep)
|
||||||
|
|
||||||
|
multi := len(values) / len(names)
|
||||||
|
|
||||||
|
if isMulti {
|
||||||
|
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
|
||||||
|
|
||||||
|
d.ins.ReplaceMarks(&query)
|
||||||
|
|
||||||
|
if isMulti || !d.ins.HasReturningID(mi, &query) {
|
||||||
|
res, err := q.Exec(query, values...)
|
||||||
|
if err == nil {
|
||||||
|
if isMulti {
|
||||||
|
return res.RowsAffected()
|
||||||
|
}
|
||||||
|
return res.LastInsertId()
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
row := q.QueryRow(query, values...)
|
||||||
|
var id int64
|
||||||
|
err := row.Scan(&id)
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
@ -134,7 +134,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
|
|||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var tmp, index sql.NullString
|
var tmp, index sql.NullString
|
||||||
rows.Scan(&tmp, &index, &tmp)
|
rows.Scan(&tmp, &index, &tmp, &tmp, &tmp)
|
||||||
if name == index.String {
|
if name == index.String {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -63,7 +63,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
|
|||||||
// add table info to collection.
|
// add table info to collection.
|
||||||
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
|
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
|
||||||
name := strings.Join(names, ExprSep)
|
name := strings.Join(names, ExprSep)
|
||||||
if _, ok := t.tablesM[name]; ok == false {
|
if _, ok := t.tablesM[name]; !ok {
|
||||||
i := len(t.tables) + 1
|
i := len(t.tables) + 1
|
||||||
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
|
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
|
||||||
t.tablesM[name] = jt
|
t.tablesM[name] = jt
|
||||||
@ -261,7 +261,7 @@ loopFor:
|
|||||||
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
|
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRel && (fi.mi.isThrough == false || num != i) {
|
if isRel && (!fi.mi.isThrough || num != i) {
|
||||||
if fi.null || t.skipEnd {
|
if fi.null || t.skipEnd {
|
||||||
inner = false
|
inner = false
|
||||||
}
|
}
|
||||||
@ -364,7 +364,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe
|
|||||||
}
|
}
|
||||||
|
|
||||||
index, _, fi, suc := t.parseExprs(mi, exprs)
|
index, _, fi, suc := t.parseExprs(mi, exprs)
|
||||||
if suc == false {
|
if !suc {
|
||||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -383,7 +383,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sub == false && where != "" {
|
if !sub && where != "" {
|
||||||
where = "WHERE " + where
|
where = "WHERE " + where
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -403,7 +403,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
|
|||||||
exprs := strings.Split(group, ExprSep)
|
exprs := strings.Split(group, ExprSep)
|
||||||
|
|
||||||
index, _, fi, suc := t.parseExprs(t.mi, exprs)
|
index, _, fi, suc := t.parseExprs(t.mi, exprs)
|
||||||
if suc == false {
|
if !suc {
|
||||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -432,7 +432,7 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
|
|||||||
exprs := strings.Split(order, ExprSep)
|
exprs := strings.Split(order, ExprSep)
|
||||||
|
|
||||||
index, _, fi, suc := t.parseExprs(t.mi, exprs)
|
index, _, fi, suc := t.parseExprs(t.mi, exprs)
|
||||||
if suc == false {
|
if !suc {
|
||||||
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,6 +41,8 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
|
|||||||
vu := v.Int()
|
vu := v.Int()
|
||||||
exist = true
|
exist = true
|
||||||
value = vu
|
value = vu
|
||||||
|
} else if fi.fieldType&IsRelField > 0 {
|
||||||
|
_, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v))
|
||||||
} else {
|
} else {
|
||||||
vu := v.String()
|
vu := v.String()
|
||||||
exist = vu != ""
|
exist = vu != ""
|
||||||
|
@ -75,7 +75,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if mi.fields.pk == nil {
|
if mi.fields.pk == nil {
|
||||||
fmt.Printf("<orm.RegisterModel> `%s` need a primary key field, default use 'id' if not set\n", name)
|
fmt.Printf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
|
||||||
os.Exit(2)
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,7 +117,7 @@ func bootStrap() {
|
|||||||
name := getFullName(elm)
|
name := getFullName(elm)
|
||||||
mii, ok := modelCache.getByFullName(name)
|
mii, ok := modelCache.getByFullName(name)
|
||||||
if !ok || mii.pkg != elm.PkgPath() {
|
if !ok || mii.pkg != elm.PkgPath() {
|
||||||
err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
|
err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
|
||||||
goto end
|
goto end
|
||||||
}
|
}
|
||||||
fi.relModelInfo = mii
|
fi.relModelInfo = mii
|
||||||
@ -128,7 +128,7 @@ func bootStrap() {
|
|||||||
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
|
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
|
||||||
pn := fi.relThrough[:i]
|
pn := fi.relThrough[:i]
|
||||||
rmi, ok := modelCache.getByFullName(fi.relThrough)
|
rmi, ok := modelCache.getByFullName(fi.relThrough)
|
||||||
if ok == false || pn != rmi.pkg {
|
if !ok || pn != rmi.pkg {
|
||||||
err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
|
err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
|
||||||
goto end
|
goto end
|
||||||
}
|
}
|
||||||
@ -171,7 +171,7 @@ func bootStrap() {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if inModel == false {
|
if !inModel {
|
||||||
rmi := fi.relModelInfo
|
rmi := fi.relModelInfo
|
||||||
ffi := new(fieldInfo)
|
ffi := new(fieldInfo)
|
||||||
ffi.name = mi.name
|
ffi.name = mi.name
|
||||||
@ -185,7 +185,7 @@ func bootStrap() {
|
|||||||
} else {
|
} else {
|
||||||
ffi.fieldType = RelReverseMany
|
ffi.fieldType = RelReverseMany
|
||||||
}
|
}
|
||||||
if rmi.fields.Add(ffi) == false {
|
if !rmi.fields.Add(ffi) {
|
||||||
added := false
|
added := false
|
||||||
for cnt := 0; cnt < 5; cnt++ {
|
for cnt := 0; cnt < 5; cnt++ {
|
||||||
ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
|
ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
|
||||||
@ -195,7 +195,7 @@ func bootStrap() {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if added == false {
|
if !added {
|
||||||
panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
|
panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -248,7 +248,7 @@ func bootStrap() {
|
|||||||
break mForA
|
break mForA
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if found == false {
|
if !found {
|
||||||
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
|
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
|
||||||
goto end
|
goto end
|
||||||
}
|
}
|
||||||
@ -267,7 +267,7 @@ func bootStrap() {
|
|||||||
break mForB
|
break mForB
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if found == false {
|
if !found {
|
||||||
mForC:
|
mForC:
|
||||||
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
|
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
|
||||||
conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
|
conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
|
||||||
@ -287,7 +287,7 @@ func bootStrap() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if found == false {
|
if !found {
|
||||||
err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
|
err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
|
||||||
goto end
|
goto end
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,7 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
|
|||||||
} else {
|
} else {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, ok := f.fieldsByType[fi.fieldType]; ok == false {
|
if _, ok := f.fieldsByType[fi.fieldType]; !ok {
|
||||||
f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0)
|
f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0)
|
||||||
}
|
}
|
||||||
f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi)
|
f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi)
|
||||||
@ -334,12 +334,12 @@ checkType:
|
|||||||
switch onDelete {
|
switch onDelete {
|
||||||
case odCascade, odDoNothing:
|
case odCascade, odDoNothing:
|
||||||
case odSetDefault:
|
case odSetDefault:
|
||||||
if initial.Exist() == false {
|
if !initial.Exist() {
|
||||||
err = errors.New("on_delete: set_default need set field a default value")
|
err = errors.New("on_delete: set_default need set field a default value")
|
||||||
goto end
|
goto end
|
||||||
}
|
}
|
||||||
case odSetNULL:
|
case odSetNULL:
|
||||||
if fi.null == false {
|
if !fi.null {
|
||||||
err = errors.New("on_delete: set_null need set field null")
|
err = errors.New("on_delete: set_null need set field null")
|
||||||
goto end
|
goto end
|
||||||
}
|
}
|
||||||
|
@ -78,7 +78,7 @@ func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int)
|
|||||||
fi.fieldIndex = append(index, i)
|
fi.fieldIndex = append(index, i)
|
||||||
fi.mi = mi
|
fi.mi = mi
|
||||||
fi.inModel = true
|
fi.inModel = true
|
||||||
if mi.fields.Add(fi) == false {
|
if !mi.fields.Add(fi) {
|
||||||
err = fmt.Errorf("duplicate column name: %s", fi.column)
|
err = fmt.Errorf("duplicate column name: %s", fi.column)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -406,6 +406,11 @@ type UintPk struct {
|
|||||||
Name string
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PtrPk struct {
|
||||||
|
ID *IntegerPk `orm:"pk;rel(one)"`
|
||||||
|
Positive bool
|
||||||
|
}
|
||||||
|
|
||||||
var DBARGS = struct {
|
var DBARGS = struct {
|
||||||
Driver string
|
Driver string
|
||||||
Source string
|
Source string
|
||||||
|
32
orm/orm.go
32
orm/orm.go
@ -107,7 +107,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect
|
|||||||
if mi, ok := modelCache.getByFullName(name); ok {
|
if mi, ok := modelCache.getByFullName(name); ok {
|
||||||
return mi, ind
|
return mi, ind
|
||||||
}
|
}
|
||||||
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
|
panic(fmt.Errorf("<Ormer> table: `%s` not found, make sure it was registered with `RegisterModel()`", name))
|
||||||
}
|
}
|
||||||
|
|
||||||
// get field info from model info by given field name
|
// get field info from model info by given field name
|
||||||
@ -122,21 +122,13 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
|
|||||||
// read data to model
|
// read data to model
|
||||||
func (o *orm) Read(md interface{}, cols ...string) error {
|
func (o *orm) Read(md interface{}, cols ...string) error {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
|
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
|
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
|
||||||
func (o *orm) ReadForUpdate(md interface{}, cols ...string) error {
|
func (o *orm) ReadForUpdate(md interface{}, cols ...string) error {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
|
return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to read a row from the database, or insert one if it doesn't exist
|
// Try to read a row from the database, or insert one if it doesn't exist
|
||||||
@ -153,6 +145,8 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i
|
|||||||
id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
|
id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
|
||||||
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
|
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
|
||||||
id = int64(vid.Uint())
|
id = int64(vid.Uint())
|
||||||
|
} else if mi.fields.pk.rel {
|
||||||
|
return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name)
|
||||||
} else {
|
} else {
|
||||||
id = vid.Int()
|
id = vid.Int()
|
||||||
}
|
}
|
||||||
@ -236,15 +230,11 @@ func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64
|
|||||||
// cols set the columns those want to update.
|
// cols set the columns those want to update.
|
||||||
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
|
||||||
if err != nil {
|
|
||||||
return num, err
|
|
||||||
}
|
|
||||||
return num, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete model in database
|
// delete model in database
|
||||||
// cols shows the delete conditions values read from. deafult is pk
|
// cols shows the delete conditions values read from. default is pk
|
||||||
func (o *orm) Delete(md interface{}, cols ...string) (int64, error) {
|
func (o *orm) Delete(md interface{}, cols ...string) (int64, error) {
|
||||||
mi, ind := o.getMiInd(md, true)
|
mi, ind := o.getMiInd(md, true)
|
||||||
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
|
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
|
||||||
@ -359,7 +349,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
|
|||||||
fi := o.getFieldInfo(mi, name)
|
fi := o.getFieldInfo(mi, name)
|
||||||
|
|
||||||
_, _, exist := getExistPk(mi, ind)
|
_, _, exist := getExistPk(mi, ind)
|
||||||
if exist == false {
|
if !exist {
|
||||||
panic(ErrMissPK)
|
panic(ErrMissPK)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -430,7 +420,7 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
|
|||||||
// table name can be string or struct.
|
// table name can be string or struct.
|
||||||
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
|
||||||
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
|
||||||
name := ""
|
var name string
|
||||||
if table, ok := ptrStructOrTableName.(string); ok {
|
if table, ok := ptrStructOrTableName.(string); ok {
|
||||||
name = snakeString(table)
|
name = snakeString(table)
|
||||||
if mi, ok := modelCache.get(name); ok {
|
if mi, ok := modelCache.get(name); ok {
|
||||||
@ -487,7 +477,7 @@ func (o *orm) Begin() error {
|
|||||||
|
|
||||||
// commit transaction
|
// commit transaction
|
||||||
func (o *orm) Commit() error {
|
func (o *orm) Commit() error {
|
||||||
if o.isTx == false {
|
if !o.isTx {
|
||||||
return ErrTxDone
|
return ErrTxDone
|
||||||
}
|
}
|
||||||
err := o.db.(txEnder).Commit()
|
err := o.db.(txEnder).Commit()
|
||||||
@ -502,7 +492,7 @@ func (o *orm) Commit() error {
|
|||||||
|
|
||||||
// rollback transaction
|
// rollback transaction
|
||||||
func (o *orm) Rollback() error {
|
func (o *orm) Rollback() error {
|
||||||
if o.isTx == false {
|
if !o.isTx {
|
||||||
return ErrTxDone
|
return ErrTxDone
|
||||||
}
|
}
|
||||||
err := o.db.(txEnder).Rollback()
|
err := o.db.(txEnder).Rollback()
|
||||||
|
@ -72,7 +72,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, v1, exist := getExistPk(o.mi, o.ind)
|
_, v1, exist := getExistPk(o.mi, o.ind)
|
||||||
if exist == false {
|
if !exist {
|
||||||
panic(ErrMissPK)
|
panic(ErrMissPK)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
|
|||||||
v2 = ind.Interface()
|
v2 = ind.Interface()
|
||||||
} else {
|
} else {
|
||||||
_, v2, exist = getExistPk(fi.relModelInfo, ind)
|
_, v2, exist = getExistPk(fi.relModelInfo, ind)
|
||||||
if exist == false {
|
if !exist {
|
||||||
panic(ErrMissPK)
|
panic(ErrMissPK)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -104,11 +104,7 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
|
|||||||
fi := o.fi
|
fi := o.fi
|
||||||
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
|
||||||
|
|
||||||
nums, err := qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete()
|
return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete()
|
||||||
if err != nil {
|
|
||||||
return nums, err
|
|
||||||
}
|
|
||||||
return nums, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check model is existed in relationship of origin model
|
// check model is existed in relationship of origin model
|
||||||
|
@ -153,6 +153,11 @@ func (o querySet) SetCond(cond *Condition) QuerySeter {
|
|||||||
return &o
|
return &o
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get condition from QuerySeter
|
||||||
|
func (o querySet) GetCond() *Condition {
|
||||||
|
return o.cond
|
||||||
|
}
|
||||||
|
|
||||||
// return QuerySeter execution result number
|
// return QuerySeter execution result number
|
||||||
func (o *querySet) Count() (int64, error) {
|
func (o *querySet) Count() (int64, error) {
|
||||||
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
|
||||||
|
@ -493,19 +493,33 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for i := 0; i < ind.NumField(); i++ {
|
// define recursive function
|
||||||
f := ind.Field(i)
|
var recursiveSetField func(rv reflect.Value)
|
||||||
fe := ind.Type().Field(i)
|
recursiveSetField = func(rv reflect.Value) {
|
||||||
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
|
for i := 0; i < rv.NumField(); i++ {
|
||||||
var col string
|
f := rv.Field(i)
|
||||||
if col = tags["column"]; col == "" {
|
fe := rv.Type().Field(i)
|
||||||
col = snakeString(fe.Name)
|
|
||||||
}
|
// check if the field is a Struct
|
||||||
if v, ok := columnsMp[col]; ok {
|
// recursive the Struct type
|
||||||
value := reflect.ValueOf(v).Elem().Interface()
|
if fe.Type.Kind() == reflect.Struct {
|
||||||
o.setFieldValue(f, value)
|
recursiveSetField(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
|
||||||
|
var col string
|
||||||
|
if col = tags["column"]; col == "" {
|
||||||
|
col = snakeString(fe.Name)
|
||||||
|
}
|
||||||
|
if v, ok := columnsMp[col]; ok {
|
||||||
|
value := reflect.ValueOf(v).Elem().Interface()
|
||||||
|
o.setFieldValue(f, value)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init call the recursive function
|
||||||
|
recursiveSetField(ind)
|
||||||
}
|
}
|
||||||
|
|
||||||
if eTyps[0].Kind() == reflect.Ptr {
|
if eTyps[0].Kind() == reflect.Ptr {
|
||||||
@ -671,7 +685,7 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
|
|||||||
ind *reflect.Value
|
ind *reflect.Value
|
||||||
)
|
)
|
||||||
|
|
||||||
typ := 0
|
var typ int
|
||||||
switch container.(type) {
|
switch container.(type) {
|
||||||
case *Params:
|
case *Params:
|
||||||
typ = 1
|
typ = 1
|
||||||
|
@ -93,14 +93,14 @@ wrongArg:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AssertIs(a interface{}, args ...interface{}) error {
|
func AssertIs(a interface{}, args ...interface{}) error {
|
||||||
if ok, err := ValuesCompare(true, a, args...); ok == false {
|
if ok, err := ValuesCompare(true, a, args...); !ok {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AssertNot(a interface{}, args ...interface{}) error {
|
func AssertNot(a interface{}, args ...interface{}) error {
|
||||||
if ok, err := ValuesCompare(false, a, args...); ok == false {
|
if ok, err := ValuesCompare(false, a, args...); !ok {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -135,7 +135,7 @@ func getCaller(skip int) string {
|
|||||||
if i := strings.LastIndex(funName, "."); i > -1 {
|
if i := strings.LastIndex(funName, "."); i > -1 {
|
||||||
funName = funName[i+1:]
|
funName = funName[i+1:]
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s:%d: \n%s", fn, line, strings.Join(codes, "\n"))
|
return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func throwFail(t *testing.T, err error, args ...interface{}) {
|
func throwFail(t *testing.T, err error, args ...interface{}) {
|
||||||
@ -193,6 +193,7 @@ func TestSyncDb(t *testing.T) {
|
|||||||
RegisterModel(new(InLineOneToOne))
|
RegisterModel(new(InLineOneToOne))
|
||||||
RegisterModel(new(IntegerPk))
|
RegisterModel(new(IntegerPk))
|
||||||
RegisterModel(new(UintPk))
|
RegisterModel(new(UintPk))
|
||||||
|
RegisterModel(new(PtrPk))
|
||||||
|
|
||||||
err := RunSyncdb("default", true, Debug)
|
err := RunSyncdb("default", true, Debug)
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
@ -216,6 +217,7 @@ func TestRegisterModels(t *testing.T) {
|
|||||||
RegisterModel(new(InLineOneToOne))
|
RegisterModel(new(InLineOneToOne))
|
||||||
RegisterModel(new(IntegerPk))
|
RegisterModel(new(IntegerPk))
|
||||||
RegisterModel(new(UintPk))
|
RegisterModel(new(UintPk))
|
||||||
|
RegisterModel(new(PtrPk))
|
||||||
|
|
||||||
BootStrap()
|
BootStrap()
|
||||||
|
|
||||||
@ -1012,6 +1014,8 @@ func TestAll(t *testing.T) {
|
|||||||
var users3 []*User
|
var users3 []*User
|
||||||
qs = dORM.QueryTable("user")
|
qs = dORM.QueryTable("user")
|
||||||
num, err = qs.Filter("user_name", "nothing").All(&users3)
|
num, err = qs.Filter("user_name", "nothing").All(&users3)
|
||||||
|
throwFailNow(t, err)
|
||||||
|
throwFailNow(t, AssertIs(num, 0))
|
||||||
throwFailNow(t, AssertIs(users3 == nil, false))
|
throwFailNow(t, AssertIs(users3 == nil, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1136,6 +1140,7 @@ func TestRelatedSel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user)
|
err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user)
|
||||||
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(num, 1))
|
throwFail(t, AssertIs(num, 1))
|
||||||
throwFail(t, AssertIs(user.Profile, nil))
|
throwFail(t, AssertIs(user.Profile, nil))
|
||||||
|
|
||||||
@ -1244,20 +1249,24 @@ func TestLoadRelated(t *testing.T) {
|
|||||||
|
|
||||||
num, err = dORM.LoadRelated(&user, "Posts", true)
|
num, err = dORM.LoadRelated(&user, "Posts", true)
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
|
throwFailNow(t, AssertIs(num, 2))
|
||||||
throwFailNow(t, AssertIs(len(user.Posts), 2))
|
throwFailNow(t, AssertIs(len(user.Posts), 2))
|
||||||
throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie"))
|
throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie"))
|
||||||
|
|
||||||
num, err = dORM.LoadRelated(&user, "Posts", true, 1)
|
num, err = dORM.LoadRelated(&user, "Posts", true, 1)
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
|
throwFailNow(t, AssertIs(num, 1))
|
||||||
throwFailNow(t, AssertIs(len(user.Posts), 1))
|
throwFailNow(t, AssertIs(len(user.Posts), 1))
|
||||||
|
|
||||||
num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id")
|
num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id")
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
|
throwFailNow(t, AssertIs(num, 2))
|
||||||
throwFailNow(t, AssertIs(len(user.Posts), 2))
|
throwFailNow(t, AssertIs(len(user.Posts), 2))
|
||||||
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
|
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
|
||||||
|
|
||||||
num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id")
|
num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id")
|
||||||
throwFailNow(t, err)
|
throwFailNow(t, err)
|
||||||
|
throwFailNow(t, AssertIs(num, 1))
|
||||||
throwFailNow(t, AssertIs(len(user.Posts), 1))
|
throwFailNow(t, AssertIs(len(user.Posts), 1))
|
||||||
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
|
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
|
||||||
|
|
||||||
@ -1652,6 +1661,13 @@ func TestRawQueryRow(t *testing.T) {
|
|||||||
throwFail(t, AssertIs(pid, nil))
|
throwFail(t, AssertIs(pid, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// user_profile table
|
||||||
|
type userProfile struct {
|
||||||
|
User
|
||||||
|
Age int
|
||||||
|
Money float64
|
||||||
|
}
|
||||||
|
|
||||||
func TestQueryRows(t *testing.T) {
|
func TestQueryRows(t *testing.T) {
|
||||||
Q := dDbBaser.TableQuote()
|
Q := dDbBaser.TableQuote()
|
||||||
|
|
||||||
@ -1722,6 +1738,19 @@ func TestQueryRows(t *testing.T) {
|
|||||||
throwFailNow(t, AssertIs(usernames[1], "astaxie"))
|
throwFailNow(t, AssertIs(usernames[1], "astaxie"))
|
||||||
throwFailNow(t, AssertIs(ids[2], 4))
|
throwFailNow(t, AssertIs(ids[2], 4))
|
||||||
throwFailNow(t, AssertIs(usernames[2], "nobody"))
|
throwFailNow(t, AssertIs(usernames[2], "nobody"))
|
||||||
|
|
||||||
|
//test query rows by nested struct
|
||||||
|
var l []userProfile
|
||||||
|
query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q)
|
||||||
|
num, err = dORM.Raw(query).QueryRows(&l)
|
||||||
|
throwFailNow(t, err)
|
||||||
|
throwFailNow(t, AssertIs(num, 2))
|
||||||
|
throwFailNow(t, AssertIs(len(l), 2))
|
||||||
|
throwFailNow(t, AssertIs(l[0].UserName, "slene"))
|
||||||
|
throwFailNow(t, AssertIs(l[0].Age, 28))
|
||||||
|
throwFailNow(t, AssertIs(l[1].UserName, "astaxie"))
|
||||||
|
throwFailNow(t, AssertIs(l[1].Age, 30))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRawValues(t *testing.T) {
|
func TestRawValues(t *testing.T) {
|
||||||
@ -1974,6 +2003,7 @@ func TestReadOrCreate(t *testing.T) {
|
|||||||
created, pk, err := dORM.ReadOrCreate(u, "UserName")
|
created, pk, err := dORM.ReadOrCreate(u, "UserName")
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(created, true))
|
throwFail(t, AssertIs(created, true))
|
||||||
|
throwFail(t, AssertIs(u.ID, pk))
|
||||||
throwFail(t, AssertIs(u.UserName, "Kyle"))
|
throwFail(t, AssertIs(u.UserName, "Kyle"))
|
||||||
throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com"))
|
throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com"))
|
||||||
throwFail(t, AssertIs(u.Password, "other_pass"))
|
throwFail(t, AssertIs(u.Password, "other_pass"))
|
||||||
@ -2128,13 +2158,13 @@ func TestUintPk(t *testing.T) {
|
|||||||
Name: name,
|
Name: name,
|
||||||
}
|
}
|
||||||
|
|
||||||
created, pk, err := dORM.ReadOrCreate(u, "ID")
|
created, _, err := dORM.ReadOrCreate(u, "ID")
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(created, true))
|
throwFail(t, AssertIs(created, true))
|
||||||
throwFail(t, AssertIs(u.Name, name))
|
throwFail(t, AssertIs(u.Name, name))
|
||||||
|
|
||||||
nu := &UintPk{ID: 8}
|
nu := &UintPk{ID: 8}
|
||||||
created, pk, err = dORM.ReadOrCreate(nu, "ID")
|
created, pk, err := dORM.ReadOrCreate(nu, "ID")
|
||||||
throwFail(t, err)
|
throwFail(t, err)
|
||||||
throwFail(t, AssertIs(created, false))
|
throwFail(t, AssertIs(created, false))
|
||||||
throwFail(t, AssertIs(nu.ID, u.ID))
|
throwFail(t, AssertIs(nu.ID, u.ID))
|
||||||
@ -2144,6 +2174,48 @@ func TestUintPk(t *testing.T) {
|
|||||||
dORM.Delete(u)
|
dORM.Delete(u)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPtrPk(t *testing.T) {
|
||||||
|
parent := &IntegerPk{ID: 10, Value: "10"}
|
||||||
|
|
||||||
|
id, _ := dORM.Insert(parent)
|
||||||
|
if !IsMysql {
|
||||||
|
// MySql does not support last_insert_id in this case: see #2382
|
||||||
|
throwFail(t, AssertIs(id, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
ptr := PtrPk{ID: parent, Positive: true}
|
||||||
|
num, err := dORM.InsertMulti(2, []PtrPk{ptr})
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
throwFail(t, AssertIs(ptr.ID, parent))
|
||||||
|
|
||||||
|
nptr := &PtrPk{ID: parent}
|
||||||
|
created, pk, err := dORM.ReadOrCreate(nptr, "ID")
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(created, false))
|
||||||
|
throwFail(t, AssertIs(pk, 10))
|
||||||
|
throwFail(t, AssertIs(nptr.ID, parent))
|
||||||
|
throwFail(t, AssertIs(nptr.Positive, true))
|
||||||
|
|
||||||
|
nptr = &PtrPk{Positive: true}
|
||||||
|
created, pk, err = dORM.ReadOrCreate(nptr, "Positive")
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(created, false))
|
||||||
|
throwFail(t, AssertIs(pk, 10))
|
||||||
|
throwFail(t, AssertIs(nptr.ID, parent))
|
||||||
|
|
||||||
|
nptr.Positive = false
|
||||||
|
num, err = dORM.Update(nptr)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
throwFail(t, AssertIs(nptr.ID, parent))
|
||||||
|
throwFail(t, AssertIs(nptr.Positive, false))
|
||||||
|
|
||||||
|
num, err = dORM.Delete(nptr)
|
||||||
|
throwFail(t, err)
|
||||||
|
throwFail(t, AssertIs(num, 1))
|
||||||
|
}
|
||||||
|
|
||||||
func TestSnake(t *testing.T) {
|
func TestSnake(t *testing.T) {
|
||||||
cases := map[string]string{
|
cases := map[string]string{
|
||||||
"i": "i",
|
"i": "i",
|
||||||
|
10
orm/types.go
10
orm/types.go
@ -145,6 +145,16 @@ type QuerySeter interface {
|
|||||||
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
|
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
|
||||||
// num, err := qs.SetCond(cond1).Count()
|
// num, err := qs.SetCond(cond1).Count()
|
||||||
SetCond(*Condition) QuerySeter
|
SetCond(*Condition) QuerySeter
|
||||||
|
// get condition from QuerySeter.
|
||||||
|
// sql's where condition
|
||||||
|
// cond := orm.NewCondition()
|
||||||
|
// cond = cond.And("profile__isnull", false).AndNot("status__in", 1)
|
||||||
|
// qs = qs.SetCond(cond)
|
||||||
|
// cond = qs.GetCond()
|
||||||
|
// cond := cond.Or("profile__age__gt", 2000)
|
||||||
|
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
|
||||||
|
// num, err := qs.SetCond(cond).Count()
|
||||||
|
GetCond() *Condition
|
||||||
// add LIMIT value.
|
// add LIMIT value.
|
||||||
// args[0] means offset, e.g. LIMIT num,offset.
|
// args[0] means offset, e.g. LIMIT num,offset.
|
||||||
// if Limit <= 0 then Limit will be set to default limit ,eg 1000
|
// if Limit <= 0 then Limit will be set to default limit ,eg 1000
|
||||||
|
29
orm/utils.go
29
orm/utils.go
@ -92,11 +92,11 @@ func (f StrTo) Int64() (int64, error) {
|
|||||||
i := new(big.Int)
|
i := new(big.Int)
|
||||||
ni, ok := i.SetString(f.String(), 10) // octal
|
ni, ok := i.SetString(f.String(), 10) // octal
|
||||||
if !ok {
|
if !ok {
|
||||||
return int64(v), err
|
return v, err
|
||||||
}
|
}
|
||||||
return ni.Int64(), nil
|
return ni.Int64(), nil
|
||||||
}
|
}
|
||||||
return int64(v), err
|
return v, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Uint string to uint
|
// Uint string to uint
|
||||||
@ -130,11 +130,11 @@ func (f StrTo) Uint64() (uint64, error) {
|
|||||||
i := new(big.Int)
|
i := new(big.Int)
|
||||||
ni, ok := i.SetString(f.String(), 10)
|
ni, ok := i.SetString(f.String(), 10)
|
||||||
if !ok {
|
if !ok {
|
||||||
return uint64(v), err
|
return v, err
|
||||||
}
|
}
|
||||||
return ni.Uint64(), nil
|
return ni.Uint64(), nil
|
||||||
}
|
}
|
||||||
return uint64(v), err
|
return v, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// String string to string
|
// String string to string
|
||||||
@ -219,22 +219,17 @@ func snakeString(s string) string {
|
|||||||
// camel string, xx_yy to XxYy
|
// camel string, xx_yy to XxYy
|
||||||
func camelString(s string) string {
|
func camelString(s string) string {
|
||||||
data := make([]byte, 0, len(s))
|
data := make([]byte, 0, len(s))
|
||||||
j := false
|
flag, num := true, len(s)-1
|
||||||
k := false
|
|
||||||
num := len(s) - 1
|
|
||||||
for i := 0; i <= num; i++ {
|
for i := 0; i <= num; i++ {
|
||||||
d := s[i]
|
d := s[i]
|
||||||
if k == false && d >= 'A' && d <= 'Z' {
|
if d == '_' {
|
||||||
k = true
|
flag = true
|
||||||
}
|
|
||||||
if d >= 'a' && d <= 'z' && (j || k == false) {
|
|
||||||
d = d - 32
|
|
||||||
j = false
|
|
||||||
k = true
|
|
||||||
}
|
|
||||||
if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' {
|
|
||||||
j = true
|
|
||||||
continue
|
continue
|
||||||
|
} else if flag {
|
||||||
|
if d >= 'a' && d <= 'z' {
|
||||||
|
d = d - 32
|
||||||
|
}
|
||||||
|
flag = false
|
||||||
}
|
}
|
||||||
data = append(data, d)
|
data = append(data, d)
|
||||||
}
|
}
|
||||||
|
36
orm/utils_test.go
Normal file
36
orm/utils_test.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
// Copyright 2014 beego Author. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCamelString(t *testing.T) {
|
||||||
|
snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"}
|
||||||
|
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"}
|
||||||
|
|
||||||
|
answer := make(map[string]string)
|
||||||
|
for i, v := range snake {
|
||||||
|
answer[v] = camel[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range snake {
|
||||||
|
res := camelString(v)
|
||||||
|
if res != answer[v] {
|
||||||
|
t.Error("Unit Test Fail:", v, res, answer[v])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
212
parser.go
212
parser.go
@ -24,9 +24,13 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/astaxie/beego/context/param"
|
||||||
"github.com/astaxie/beego/logs"
|
"github.com/astaxie/beego/logs"
|
||||||
"github.com/astaxie/beego/utils"
|
"github.com/astaxie/beego/utils"
|
||||||
)
|
)
|
||||||
@ -35,6 +39,7 @@ var globalRouterTemplate = `package routers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/astaxie/beego"
|
"github.com/astaxie/beego"
|
||||||
|
"github.com/astaxie/beego/context/param"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -81,7 +86,7 @@ func parserPkg(pkgRealpath, pkgpath string) error {
|
|||||||
if specDecl.Recv != nil {
|
if specDecl.Recv != nil {
|
||||||
exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser
|
exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser
|
||||||
if ok {
|
if ok {
|
||||||
parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath)
|
parserComments(specDecl, fmt.Sprint(exp.X), pkgpath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -93,44 +98,170 @@ func parserPkg(pkgRealpath, pkgpath string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpath string) error {
|
type parsedComment struct {
|
||||||
if comments != nil && comments.List != nil {
|
routerPath string
|
||||||
for _, c := range comments.List {
|
methods []string
|
||||||
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
|
params map[string]parsedParam
|
||||||
if strings.HasPrefix(t, "@router") {
|
}
|
||||||
elements := strings.TrimLeft(t, "@router ")
|
|
||||||
e1 := strings.SplitN(elements, " ", 2)
|
type parsedParam struct {
|
||||||
if len(e1) < 1 {
|
name string
|
||||||
return errors.New("you should has router information")
|
datatype string
|
||||||
}
|
location string
|
||||||
key := pkgpath + ":" + controllerName
|
defValue string
|
||||||
cc := ControllerComments{}
|
required bool
|
||||||
cc.Method = funcName
|
}
|
||||||
cc.Router = e1[0]
|
|
||||||
if len(e1) == 2 && e1[1] != "" {
|
func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
|
||||||
e1 = strings.SplitN(e1[1], " ", 2)
|
if f.Doc != nil {
|
||||||
if len(e1) >= 1 {
|
parsedComment, err := parseComment(f.Doc.List)
|
||||||
cc.AllowHTTPMethods = strings.Split(strings.Trim(e1[0], "[]"), ",")
|
if err != nil {
|
||||||
} else {
|
return err
|
||||||
cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get")
|
|
||||||
}
|
|
||||||
if len(e1) == 2 && e1[1] != "" {
|
|
||||||
keyval := strings.Split(strings.Trim(e1[1], "[]"), " ")
|
|
||||||
for _, kv := range keyval {
|
|
||||||
kk := strings.Split(kv, ":")
|
|
||||||
cc.Params = append(cc.Params, map[string]string{strings.Join(kk[:len(kk)-1], ":"): kk[len(kk)-1]})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
genInfoList[key] = append(genInfoList[key], cc)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if parsedComment.routerPath != "" {
|
||||||
|
key := pkgpath + ":" + controllerName
|
||||||
|
cc := ControllerComments{}
|
||||||
|
cc.Method = f.Name.String()
|
||||||
|
cc.Router = parsedComment.routerPath
|
||||||
|
cc.AllowHTTPMethods = parsedComment.methods
|
||||||
|
cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
|
||||||
|
genInfoList[key] = append(genInfoList[key], cc)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
|
||||||
|
result := make([]*param.MethodParam, 0, len(funcParams))
|
||||||
|
for _, fparam := range funcParams {
|
||||||
|
for _, pName := range fparam.Names {
|
||||||
|
methodParam := buildMethodParam(fparam, pName.Name, pc)
|
||||||
|
result = append(result, methodParam)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam {
|
||||||
|
options := []param.MethodParamOption{}
|
||||||
|
if cparam, ok := pc.params[name]; ok {
|
||||||
|
//Build param from comment info
|
||||||
|
name = cparam.name
|
||||||
|
if cparam.required {
|
||||||
|
options = append(options, param.IsRequired)
|
||||||
|
}
|
||||||
|
switch cparam.location {
|
||||||
|
case "body":
|
||||||
|
options = append(options, param.InBody)
|
||||||
|
case "header":
|
||||||
|
options = append(options, param.InHeader)
|
||||||
|
case "path":
|
||||||
|
options = append(options, param.InPath)
|
||||||
|
}
|
||||||
|
if cparam.defValue != "" {
|
||||||
|
options = append(options, param.Default(cparam.defValue))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if paramInPath(name, pc.routerPath) {
|
||||||
|
options = append(options, param.InPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return param.New(name, options...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func paramInPath(name, route string) bool {
|
||||||
|
return strings.HasSuffix(route, ":"+name) ||
|
||||||
|
strings.Contains(route, ":"+name+"/")
|
||||||
|
}
|
||||||
|
|
||||||
|
var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
|
||||||
|
|
||||||
|
func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) {
|
||||||
|
pc = &parsedComment{}
|
||||||
|
for _, c := range lines {
|
||||||
|
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
|
||||||
|
if strings.HasPrefix(t, "@router") {
|
||||||
|
matches := routeRegex.FindStringSubmatch(t)
|
||||||
|
if len(matches) == 3 {
|
||||||
|
pc.routerPath = matches[1]
|
||||||
|
methods := matches[2]
|
||||||
|
if methods == "" {
|
||||||
|
pc.methods = []string{"get"}
|
||||||
|
//pc.hasGet = true
|
||||||
|
} else {
|
||||||
|
pc.methods = strings.Split(methods, ",")
|
||||||
|
//pc.hasGet = strings.Contains(methods, "get")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("Router information is missing")
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(t, "@Param") {
|
||||||
|
pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param")))
|
||||||
|
if len(pv) < 4 {
|
||||||
|
logs.Error("Invalid @Param format. Needs at least 4 parameters")
|
||||||
|
}
|
||||||
|
p := parsedParam{}
|
||||||
|
names := strings.SplitN(pv[0], "=>", 2)
|
||||||
|
p.name = names[0]
|
||||||
|
funcParamName := p.name
|
||||||
|
if len(names) > 1 {
|
||||||
|
funcParamName = names[1]
|
||||||
|
}
|
||||||
|
p.location = pv[1]
|
||||||
|
p.datatype = pv[2]
|
||||||
|
switch len(pv) {
|
||||||
|
case 5:
|
||||||
|
p.required, _ = strconv.ParseBool(pv[3])
|
||||||
|
case 6:
|
||||||
|
p.defValue = pv[3]
|
||||||
|
p.required, _ = strconv.ParseBool(pv[4])
|
||||||
|
}
|
||||||
|
if pc.params == nil {
|
||||||
|
pc.params = map[string]parsedParam{}
|
||||||
|
}
|
||||||
|
pc.params[funcParamName] = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// direct copy from bee\g_docs.go
|
||||||
|
// analysis params return []string
|
||||||
|
// @Param query form string true "The email for login"
|
||||||
|
// [query form string true "The email for login"]
|
||||||
|
func getparams(str string) []string {
|
||||||
|
var s []rune
|
||||||
|
var j int
|
||||||
|
var start bool
|
||||||
|
var r []string
|
||||||
|
var quoted int8
|
||||||
|
for _, c := range str {
|
||||||
|
if unicode.IsSpace(c) && quoted == 0 {
|
||||||
|
if !start {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
start = false
|
||||||
|
j++
|
||||||
|
r = append(r, string(s))
|
||||||
|
s = make([]rune, 0)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
start = true
|
||||||
|
if c == '"' {
|
||||||
|
quoted ^= 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s = append(s, c)
|
||||||
|
}
|
||||||
|
if len(s) > 0 {
|
||||||
|
r = append(r, string(s))
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
func genRouterCode(pkgRealpath string) {
|
func genRouterCode(pkgRealpath string) {
|
||||||
os.Mkdir(getRouterDir(pkgRealpath), 0755)
|
os.Mkdir(getRouterDir(pkgRealpath), 0755)
|
||||||
logs.Info("generate router from comments")
|
logs.Info("generate router from comments")
|
||||||
@ -144,6 +275,7 @@ func genRouterCode(pkgRealpath string) {
|
|||||||
sort.Strings(sortKey)
|
sort.Strings(sortKey)
|
||||||
for _, k := range sortKey {
|
for _, k := range sortKey {
|
||||||
cList := genInfoList[k]
|
cList := genInfoList[k]
|
||||||
|
sort.Sort(ControllerCommentsSlice(cList))
|
||||||
for _, c := range cList {
|
for _, c := range cList {
|
||||||
allmethod := "nil"
|
allmethod := "nil"
|
||||||
if len(c.AllowHTTPMethods) > 0 {
|
if len(c.AllowHTTPMethods) > 0 {
|
||||||
@ -163,12 +295,24 @@ func genRouterCode(pkgRealpath string) {
|
|||||||
}
|
}
|
||||||
params = strings.TrimRight(params, ",") + "}"
|
params = strings.TrimRight(params, ",") + "}"
|
||||||
}
|
}
|
||||||
|
methodParams := "param.Make("
|
||||||
|
if len(c.MethodParams) > 0 {
|
||||||
|
lines := make([]string, 0, len(c.MethodParams))
|
||||||
|
for _, m := range c.MethodParams {
|
||||||
|
lines = append(lines, fmt.Sprint(m))
|
||||||
|
}
|
||||||
|
methodParams += "\n " +
|
||||||
|
strings.Join(lines, ",\n ") +
|
||||||
|
",\n "
|
||||||
|
}
|
||||||
|
methodParams += ")"
|
||||||
globalinfo = globalinfo + `
|
globalinfo = globalinfo + `
|
||||||
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
|
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
|
||||||
beego.ControllerComments{
|
beego.ControllerComments{
|
||||||
Method: "` + strings.TrimSpace(c.Method) + `",
|
Method: "` + strings.TrimSpace(c.Method) + `",
|
||||||
` + "Router: `" + c.Router + "`" + `,
|
` + "Router: `" + c.Router + "`" + `,
|
||||||
AllowHTTPMethods: ` + allmethod + `,
|
AllowHTTPMethods: ` + allmethod + `,
|
||||||
|
MethodParams: ` + methodParams + `,
|
||||||
Params: ` + params + `})
|
Params: ` + params + `})
|
||||||
`
|
`
|
||||||
}
|
}
|
||||||
|
@ -56,6 +56,7 @@
|
|||||||
package apiauth
|
package apiauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
@ -128,53 +129,32 @@ func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc {
|
|||||||
|
|
||||||
// Signature used to generate signature with the appsecret/method/params/RequestURI
|
// Signature used to generate signature with the appsecret/method/params/RequestURI
|
||||||
func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) {
|
func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) {
|
||||||
var query string
|
var b bytes.Buffer
|
||||||
|
keys := make([]string, len(params))
|
||||||
pa := make(map[string]string)
|
pa := make(map[string]string)
|
||||||
for k, v := range params {
|
for k, v := range params {
|
||||||
pa[k] = v[0]
|
pa[k] = v[0]
|
||||||
|
keys = append(keys, k)
|
||||||
}
|
}
|
||||||
vs := mapSorter(pa)
|
|
||||||
vs.Sort()
|
sort.Strings(keys)
|
||||||
for i := 0; i < vs.Len(); i++ {
|
|
||||||
if vs.Keys[i] == "signature" {
|
for _, key := range keys {
|
||||||
|
if key == "signature" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if vs.Keys[i] != "" && vs.Vals[i] != "" {
|
|
||||||
query = fmt.Sprintf("%v%v%v", query, vs.Keys[i], vs.Vals[i])
|
val := pa[key]
|
||||||
|
if key != "" && val != "" {
|
||||||
|
b.WriteString(key)
|
||||||
|
b.WriteString(val)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURL)
|
|
||||||
|
stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL)
|
||||||
|
|
||||||
sha256 := sha256.New
|
sha256 := sha256.New
|
||||||
hash := hmac.New(sha256, []byte(appsecret))
|
hash := hmac.New(sha256, []byte(appsecret))
|
||||||
hash.Write([]byte(stringToSign))
|
hash.Write([]byte(stringToSign))
|
||||||
return base64.StdEncoding.EncodeToString(hash.Sum(nil))
|
return base64.StdEncoding.EncodeToString(hash.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
type valSorter struct {
|
|
||||||
Keys []string
|
|
||||||
Vals []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapSorter(m map[string]string) *valSorter {
|
|
||||||
vs := &valSorter{
|
|
||||||
Keys: make([]string, 0, len(m)),
|
|
||||||
Vals: make([]string, 0, len(m)),
|
|
||||||
}
|
|
||||||
for k, v := range m {
|
|
||||||
vs.Keys = append(vs.Keys, k)
|
|
||||||
vs.Vals = append(vs.Vals, v)
|
|
||||||
}
|
|
||||||
return vs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vs *valSorter) Sort() {
|
|
||||||
sort.Sort(vs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vs *valSorter) Len() int { return len(vs.Keys) }
|
|
||||||
func (vs *valSorter) Less(i, j int) bool { return vs.Keys[i] < vs.Keys[j] }
|
|
||||||
func (vs *valSorter) Swap(i, j int) {
|
|
||||||
vs.Vals[i], vs.Vals[j] = vs.Vals[j], vs.Vals[i]
|
|
||||||
vs.Keys[i], vs.Keys[j] = vs.Keys[j], vs.Keys[i]
|
|
||||||
}
|
|
||||||
|
20
plugins/apiauth/apiauth_test.go
Normal file
20
plugins/apiauth/apiauth_test.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package apiauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSignature(t *testing.T) {
|
||||||
|
appsecret := "beego secret"
|
||||||
|
method := "GET"
|
||||||
|
RequestURL := "http://localhost/test/url"
|
||||||
|
params := make(url.Values)
|
||||||
|
params.Add("arg1", "hello")
|
||||||
|
params.Add("arg2", "beego")
|
||||||
|
|
||||||
|
signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58="
|
||||||
|
if Signature(appsecret, method, params, RequestURL) != signature {
|
||||||
|
t.Error("Signature error")
|
||||||
|
}
|
||||||
|
}
|
86
plugins/authz/authz.go
Normal file
86
plugins/authz/authz.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
// Copyright 2014 beego Author. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support.
|
||||||
|
// Simple Usage:
|
||||||
|
// import(
|
||||||
|
// "github.com/astaxie/beego"
|
||||||
|
// "github.com/astaxie/beego/plugins/authz"
|
||||||
|
// "github.com/casbin/casbin"
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// func main(){
|
||||||
|
// // mediate the access for every request
|
||||||
|
// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
|
||||||
|
// beego.Run()
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// Advanced Usage:
|
||||||
|
//
|
||||||
|
// func main(){
|
||||||
|
// e := casbin.NewEnforcer("authz_model.conf", "")
|
||||||
|
// e.AddRoleForUser("alice", "admin")
|
||||||
|
// e.AddPolicy(...)
|
||||||
|
//
|
||||||
|
// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e))
|
||||||
|
// beego.Run()
|
||||||
|
// }
|
||||||
|
package authz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/astaxie/beego"
|
||||||
|
"github.com/astaxie/beego/context"
|
||||||
|
"github.com/casbin/casbin"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewAuthorizer returns the authorizer.
|
||||||
|
// Use a casbin enforcer as input
|
||||||
|
func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc {
|
||||||
|
return func(ctx *context.Context) {
|
||||||
|
a := &BasicAuthorizer{enforcer: e}
|
||||||
|
|
||||||
|
if !a.CheckPermission(ctx.Request) {
|
||||||
|
a.RequirePermission(ctx.ResponseWriter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BasicAuthorizer stores the casbin handler
|
||||||
|
type BasicAuthorizer struct {
|
||||||
|
enforcer *casbin.Enforcer
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserName gets the user name from the request.
|
||||||
|
// Currently, only HTTP basic authentication is supported
|
||||||
|
func (a *BasicAuthorizer) GetUserName(r *http.Request) string {
|
||||||
|
username, _, _ := r.BasicAuth()
|
||||||
|
return username
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckPermission checks the user/method/path combination from the request.
|
||||||
|
// Returns true (permission granted) or false (permission forbidden)
|
||||||
|
func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool {
|
||||||
|
user := a.GetUserName(r)
|
||||||
|
method := r.Method
|
||||||
|
path := r.URL.Path
|
||||||
|
return a.enforcer.Enforce(user, path, method)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequirePermission returns the 403 Forbidden to the client
|
||||||
|
func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) {
|
||||||
|
w.WriteHeader(403)
|
||||||
|
w.Write([]byte("403 Forbidden\n"))
|
||||||
|
}
|
14
plugins/authz/authz_model.conf
Normal file
14
plugins/authz/authz_model.conf
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
[request_definition]
|
||||||
|
r = sub, obj, act
|
||||||
|
|
||||||
|
[policy_definition]
|
||||||
|
p = sub, obj, act
|
||||||
|
|
||||||
|
[role_definition]
|
||||||
|
g = _, _
|
||||||
|
|
||||||
|
[policy_effect]
|
||||||
|
e = some(where (p.eft == allow))
|
||||||
|
|
||||||
|
[matchers]
|
||||||
|
m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*")
|
7
plugins/authz/authz_policy.csv
Normal file
7
plugins/authz/authz_policy.csv
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
p, alice, /dataset1/*, GET
|
||||||
|
p, alice, /dataset1/resource1, POST
|
||||||
|
p, bob, /dataset2/resource1, *
|
||||||
|
p, bob, /dataset2/resource2, GET
|
||||||
|
p, bob, /dataset2/folder1/*, POST
|
||||||
|
p, dataset1_admin, /dataset1/*, *
|
||||||
|
g, cathy, dataset1_admin
|
|
107
plugins/authz/authz_test.go
Normal file
107
plugins/authz/authz_test.go
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
// Copyright 2014 beego Author. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package authz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/astaxie/beego"
|
||||||
|
"github.com/astaxie/beego/context"
|
||||||
|
"github.com/astaxie/beego/plugins/auth"
|
||||||
|
"github.com/casbin/casbin"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) {
|
||||||
|
r, _ := http.NewRequest(method, path, nil)
|
||||||
|
r.SetBasicAuth(user, "123")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if w.Code != code {
|
||||||
|
t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasic(t *testing.T) {
|
||||||
|
handler := beego.NewControllerRegister()
|
||||||
|
|
||||||
|
handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123"))
|
||||||
|
handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
|
||||||
|
|
||||||
|
handler.Any("*", func(ctx *context.Context) {
|
||||||
|
ctx.Output.SetStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200)
|
||||||
|
testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200)
|
||||||
|
testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200)
|
||||||
|
testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPathWildcard(t *testing.T) {
|
||||||
|
handler := beego.NewControllerRegister()
|
||||||
|
|
||||||
|
handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123"))
|
||||||
|
handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
|
||||||
|
|
||||||
|
handler.Any("*", func(ctx *context.Context) {
|
||||||
|
ctx.Output.SetStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200)
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200)
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200)
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200)
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403)
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403)
|
||||||
|
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403)
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200)
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403)
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403)
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200)
|
||||||
|
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRBAC(t *testing.T) {
|
||||||
|
handler := beego.NewControllerRegister()
|
||||||
|
|
||||||
|
handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123"))
|
||||||
|
e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")
|
||||||
|
handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e))
|
||||||
|
|
||||||
|
handler.Any("*", func(ctx *context.Context) {
|
||||||
|
ctx.Output.SetStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
// cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role.
|
||||||
|
testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200)
|
||||||
|
testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200)
|
||||||
|
testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200)
|
||||||
|
testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
|
||||||
|
testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
|
||||||
|
testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
|
||||||
|
|
||||||
|
// delete all roles on user cathy, so cathy cannot access any resources now.
|
||||||
|
e.DeleteRolesForUser("cathy")
|
||||||
|
|
||||||
|
testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403)
|
||||||
|
testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403)
|
||||||
|
testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403)
|
||||||
|
testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
|
||||||
|
testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
|
||||||
|
testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
|
||||||
|
}
|
@ -23,7 +23,7 @@ import (
|
|||||||
// PolicyFunc defines a policy function which is invoked before the controller handler is executed.
|
// PolicyFunc defines a policy function which is invoked before the controller handler is executed.
|
||||||
type PolicyFunc func(*context.Context)
|
type PolicyFunc func(*context.Context)
|
||||||
|
|
||||||
// FindRouter Find Router info for URL
|
// FindPolicy Find Router info for URL
|
||||||
func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc {
|
func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc {
|
||||||
var urlPath = cont.Input.URL()
|
var urlPath = cont.Input.URL()
|
||||||
if !BConfig.RouterCaseSensitive {
|
if !BConfig.RouterCaseSensitive {
|
||||||
@ -71,7 +71,7 @@ func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register new policy in beego
|
// Policy Register new policy in beego
|
||||||
func Policy(pattern, method string, policy ...PolicyFunc) {
|
func Policy(pattern, method string, policy ...PolicyFunc) {
|
||||||
BeeApp.Handlers.addToPolicy(method, pattern, policy...)
|
BeeApp.Handlers.addToPolicy(method, pattern, policy...)
|
||||||
}
|
}
|
||||||
|
134
router.go
134
router.go
@ -17,7 +17,6 @@ package beego
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -28,6 +27,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
beecontext "github.com/astaxie/beego/context"
|
beecontext "github.com/astaxie/beego/context"
|
||||||
|
"github.com/astaxie/beego/context/param"
|
||||||
"github.com/astaxie/beego/logs"
|
"github.com/astaxie/beego/logs"
|
||||||
"github.com/astaxie/beego/toolbox"
|
"github.com/astaxie/beego/toolbox"
|
||||||
"github.com/astaxie/beego/utils"
|
"github.com/astaxie/beego/utils"
|
||||||
@ -51,15 +51,22 @@ const (
|
|||||||
var (
|
var (
|
||||||
// HTTPMETHOD list the supported http methods.
|
// HTTPMETHOD list the supported http methods.
|
||||||
HTTPMETHOD = map[string]string{
|
HTTPMETHOD = map[string]string{
|
||||||
"GET": "GET",
|
"GET": "GET",
|
||||||
"POST": "POST",
|
"POST": "POST",
|
||||||
"PUT": "PUT",
|
"PUT": "PUT",
|
||||||
"DELETE": "DELETE",
|
"DELETE": "DELETE",
|
||||||
"PATCH": "PATCH",
|
"PATCH": "PATCH",
|
||||||
"OPTIONS": "OPTIONS",
|
"OPTIONS": "OPTIONS",
|
||||||
"HEAD": "HEAD",
|
"HEAD": "HEAD",
|
||||||
"TRACE": "TRACE",
|
"TRACE": "TRACE",
|
||||||
"CONNECT": "CONNECT",
|
"CONNECT": "CONNECT",
|
||||||
|
"MKCOL": "MKCOL",
|
||||||
|
"COPY": "COPY",
|
||||||
|
"MOVE": "MOVE",
|
||||||
|
"PROPFIND": "PROPFIND",
|
||||||
|
"PROPPATCH": "PROPPATCH",
|
||||||
|
"LOCK": "LOCK",
|
||||||
|
"UNLOCK": "UNLOCK",
|
||||||
}
|
}
|
||||||
// these beego.Controller's methods shouldn't reflect to AutoRouter
|
// these beego.Controller's methods shouldn't reflect to AutoRouter
|
||||||
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
|
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
|
||||||
@ -102,13 +109,15 @@ func ExceptMethodAppend(action string) {
|
|||||||
exceptMethod = append(exceptMethod, action)
|
exceptMethod = append(exceptMethod, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
type controllerInfo struct {
|
// ControllerInfo holds information about the controller.
|
||||||
|
type ControllerInfo struct {
|
||||||
pattern string
|
pattern string
|
||||||
controllerType reflect.Type
|
controllerType reflect.Type
|
||||||
methods map[string]string
|
methods map[string]string
|
||||||
handler http.Handler
|
handler http.Handler
|
||||||
runFunction FilterFunc
|
runFunction FilterFunc
|
||||||
routerType int
|
routerType int
|
||||||
|
methodParams []*param.MethodParam
|
||||||
}
|
}
|
||||||
|
|
||||||
// ControllerRegister containers registered router rules, controller handlers and filters.
|
// ControllerRegister containers registered router rules, controller handlers and filters.
|
||||||
@ -144,6 +153,10 @@ func NewControllerRegister() *ControllerRegister {
|
|||||||
// Add("/api",&RestController{},"get,post:ApiFunc"
|
// Add("/api",&RestController{},"get,post:ApiFunc"
|
||||||
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
||||||
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
|
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
|
||||||
|
p.addWithMethodParams(pattern, c, nil, mappingMethods...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) {
|
||||||
reflectVal := reflect.ValueOf(c)
|
reflectVal := reflect.ValueOf(c)
|
||||||
t := reflect.Indirect(reflectVal).Type()
|
t := reflect.Indirect(reflectVal).Type()
|
||||||
methods := make(map[string]string)
|
methods := make(map[string]string)
|
||||||
@ -169,11 +182,12 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
route := &controllerInfo{}
|
route := &ControllerInfo{}
|
||||||
route.pattern = pattern
|
route.pattern = pattern
|
||||||
route.methods = methods
|
route.methods = methods
|
||||||
route.routerType = routerTypeBeego
|
route.routerType = routerTypeBeego
|
||||||
route.controllerType = t
|
route.controllerType = t
|
||||||
|
route.methodParams = methodParams
|
||||||
if len(methods) == 0 {
|
if len(methods) == 0 {
|
||||||
for _, m := range HTTPMETHOD {
|
for _, m := range HTTPMETHOD {
|
||||||
p.addToRouter(m, pattern, route)
|
p.addToRouter(m, pattern, route)
|
||||||
@ -191,7 +205,7 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ControllerRegister) addToRouter(method, pattern string, r *controllerInfo) {
|
func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) {
|
||||||
if !BConfig.RouterCaseSensitive {
|
if !BConfig.RouterCaseSensitive {
|
||||||
pattern = strings.ToLower(pattern)
|
pattern = strings.ToLower(pattern)
|
||||||
}
|
}
|
||||||
@ -212,13 +226,11 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
|
|||||||
for _, c := range cList {
|
for _, c := range cList {
|
||||||
reflectVal := reflect.ValueOf(c)
|
reflectVal := reflect.ValueOf(c)
|
||||||
t := reflect.Indirect(reflectVal).Type()
|
t := reflect.Indirect(reflectVal).Type()
|
||||||
gopath := os.Getenv("GOPATH")
|
wgopath := utils.GetGOPATHs()
|
||||||
if gopath == "" {
|
if len(wgopath) == 0 {
|
||||||
panic("you are in dev mode. So please set gopath")
|
panic("you are in dev mode. So please set gopath")
|
||||||
}
|
}
|
||||||
pkgpath := ""
|
pkgpath := ""
|
||||||
|
|
||||||
wgopath := filepath.SplitList(gopath)
|
|
||||||
for _, wg := range wgopath {
|
for _, wg := range wgopath {
|
||||||
wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath()))
|
wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath()))
|
||||||
if utils.FileExists(wg) {
|
if utils.FileExists(wg) {
|
||||||
@ -240,7 +252,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
|
|||||||
key := t.PkgPath() + ":" + t.Name()
|
key := t.PkgPath() + ":" + t.Name()
|
||||||
if comm, ok := GlobalControllerRouter[key]; ok {
|
if comm, ok := GlobalControllerRouter[key]; ok {
|
||||||
for _, a := range comm {
|
for _, a := range comm {
|
||||||
p.Add(a.Router, c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
|
p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -328,7 +340,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
|
|||||||
if _, ok := HTTPMETHOD[method]; method != "*" && !ok {
|
if _, ok := HTTPMETHOD[method]; method != "*" && !ok {
|
||||||
panic("not support http method: " + method)
|
panic("not support http method: " + method)
|
||||||
}
|
}
|
||||||
route := &controllerInfo{}
|
route := &ControllerInfo{}
|
||||||
route.pattern = pattern
|
route.pattern = pattern
|
||||||
route.routerType = routerTypeRESTFul
|
route.routerType = routerTypeRESTFul
|
||||||
route.runFunction = f
|
route.runFunction = f
|
||||||
@ -354,7 +366,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
|
|||||||
|
|
||||||
// Handler add user defined Handler
|
// Handler add user defined Handler
|
||||||
func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
|
func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
|
||||||
route := &controllerInfo{}
|
route := &ControllerInfo{}
|
||||||
route.pattern = pattern
|
route.pattern = pattern
|
||||||
route.routerType = routerTypeHandler
|
route.routerType = routerTypeHandler
|
||||||
route.handler = h
|
route.handler = h
|
||||||
@ -389,7 +401,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
|
|||||||
controllerName := strings.TrimSuffix(ct.Name(), "Controller")
|
controllerName := strings.TrimSuffix(ct.Name(), "Controller")
|
||||||
for i := 0; i < rt.NumMethod(); i++ {
|
for i := 0; i < rt.NumMethod(); i++ {
|
||||||
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
|
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
|
||||||
route := &controllerInfo{}
|
route := &ControllerInfo{}
|
||||||
route.routerType = routerTypeBeego
|
route.routerType = routerTypeBeego
|
||||||
route.methods = map[string]string{"*": rt.Method(i).Name}
|
route.methods = map[string]string{"*": rt.Method(i).Name}
|
||||||
route.controllerType = ct
|
route.controllerType = ct
|
||||||
@ -495,7 +507,7 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, l := range t.leaves {
|
for _, l := range t.leaves {
|
||||||
if c, ok := l.runObject.(*controllerInfo); ok {
|
if c, ok := l.runObject.(*ControllerInfo); ok {
|
||||||
if c.routerType == routerTypeBeego &&
|
if c.routerType == routerTypeBeego &&
|
||||||
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) {
|
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) {
|
||||||
find := false
|
find := false
|
||||||
@ -619,11 +631,12 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str
|
|||||||
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
var (
|
var (
|
||||||
runRouter reflect.Type
|
runRouter reflect.Type
|
||||||
findRouter bool
|
findRouter bool
|
||||||
runMethod string
|
runMethod string
|
||||||
routerInfo *controllerInfo
|
methodParams []*param.MethodParam
|
||||||
isRunnable bool
|
routerInfo *ControllerInfo
|
||||||
|
isRunnable bool
|
||||||
)
|
)
|
||||||
context := p.pool.Get().(*beecontext.Context)
|
context := p.pool.Get().(*beecontext.Context)
|
||||||
context.Reset(rw, r)
|
context.Reset(rw, r)
|
||||||
@ -663,7 +676,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
goto Admin
|
goto Admin
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method != "GET" && r.Method != "HEAD" {
|
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||||
if BConfig.CopyRequestBody && !context.Input.IsUpload() {
|
if BConfig.CopyRequestBody && !context.Input.IsUpload() {
|
||||||
context.Input.CopyBody(BConfig.MaxMemory)
|
context.Input.CopyBody(BConfig.MaxMemory)
|
||||||
}
|
}
|
||||||
@ -691,7 +704,6 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
// User can define RunController and RunMethod in filter
|
// User can define RunController and RunMethod in filter
|
||||||
if context.Input.RunController != nil && context.Input.RunMethod != "" {
|
if context.Input.RunController != nil && context.Input.RunMethod != "" {
|
||||||
findRouter = true
|
findRouter = true
|
||||||
isRunnable = true
|
|
||||||
runMethod = context.Input.RunMethod
|
runMethod = context.Input.RunMethod
|
||||||
runRouter = context.Input.RunController
|
runRouter = context.Input.RunController
|
||||||
} else {
|
} else {
|
||||||
@ -735,12 +747,13 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
routerInfo.handler.ServeHTTP(rw, r)
|
routerInfo.handler.ServeHTTP(rw, r)
|
||||||
} else {
|
} else {
|
||||||
runRouter = routerInfo.controllerType
|
runRouter = routerInfo.controllerType
|
||||||
|
methodParams = routerInfo.methodParams
|
||||||
method := r.Method
|
method := r.Method
|
||||||
if r.Method == "POST" && context.Input.Query("_method") == "PUT" {
|
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost {
|
||||||
method = "PUT"
|
method = http.MethodPut
|
||||||
}
|
}
|
||||||
if r.Method == "POST" && context.Input.Query("_method") == "DELETE" {
|
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete {
|
||||||
method = "DELETE"
|
method = http.MethodDelete
|
||||||
}
|
}
|
||||||
if m, ok := routerInfo.methods[method]; ok {
|
if m, ok := routerInfo.methods[method]; ok {
|
||||||
runMethod = m
|
runMethod = m
|
||||||
@ -770,8 +783,8 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
|
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
|
||||||
if BConfig.WebConfig.EnableXSRF {
|
if BConfig.WebConfig.EnableXSRF {
|
||||||
execController.XSRFToken()
|
execController.XSRFToken()
|
||||||
if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" ||
|
if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut ||
|
||||||
(r.Method == "POST" && (context.Input.Query("_method") == "DELETE" || context.Input.Query("_method") == "PUT")) {
|
(r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) {
|
||||||
execController.CheckXSRFCookie()
|
execController.CheckXSRFCookie()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -781,25 +794,30 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||||||
if !context.ResponseWriter.Started {
|
if !context.ResponseWriter.Started {
|
||||||
//exec main logic
|
//exec main logic
|
||||||
switch runMethod {
|
switch runMethod {
|
||||||
case "GET":
|
case http.MethodGet:
|
||||||
execController.Get()
|
execController.Get()
|
||||||
case "POST":
|
case http.MethodPost:
|
||||||
execController.Post()
|
execController.Post()
|
||||||
case "DELETE":
|
case http.MethodDelete:
|
||||||
execController.Delete()
|
execController.Delete()
|
||||||
case "PUT":
|
case http.MethodPut:
|
||||||
execController.Put()
|
execController.Put()
|
||||||
case "HEAD":
|
case http.MethodHead:
|
||||||
execController.Head()
|
execController.Head()
|
||||||
case "PATCH":
|
case http.MethodPatch:
|
||||||
execController.Patch()
|
execController.Patch()
|
||||||
case "OPTIONS":
|
case http.MethodOptions:
|
||||||
execController.Options()
|
execController.Options()
|
||||||
default:
|
default:
|
||||||
if !execController.HandlerFunc(runMethod) {
|
if !execController.HandlerFunc(runMethod) {
|
||||||
var in []reflect.Value
|
|
||||||
method := vc.MethodByName(runMethod)
|
method := vc.MethodByName(runMethod)
|
||||||
method.Call(in)
|
in := param.ConvertParams(methodParams, method.Type(), context)
|
||||||
|
out := method.Call(in)
|
||||||
|
|
||||||
|
//For backward compatibility we only handle response if we had incoming methodParams
|
||||||
|
if methodParams != nil {
|
||||||
|
p.handleParamResponse(context, execController, out)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -830,7 +848,15 @@ Admin:
|
|||||||
//admin module record QPS
|
//admin module record QPS
|
||||||
if BConfig.Listen.EnableAdmin {
|
if BConfig.Listen.EnableAdmin {
|
||||||
timeDur := time.Since(startTime)
|
timeDur := time.Since(startTime)
|
||||||
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) {
|
pattern := ""
|
||||||
|
if routerInfo != nil {
|
||||||
|
pattern = routerInfo.pattern
|
||||||
|
}
|
||||||
|
statusCode := context.ResponseWriter.Status
|
||||||
|
if statusCode == 0 {
|
||||||
|
statusCode = 200
|
||||||
|
}
|
||||||
|
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) {
|
||||||
if runRouter != nil {
|
if runRouter != nil {
|
||||||
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
|
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
|
||||||
} else {
|
} else {
|
||||||
@ -879,8 +905,22 @@ Admin:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
|
||||||
|
//looping in reverse order for the case when both error and value are returned and error sets the response status code
|
||||||
|
for i := len(results) - 1; i >= 0; i-- {
|
||||||
|
result := results[i]
|
||||||
|
if result.Kind() != reflect.Interface || !result.IsNil() {
|
||||||
|
resultValue := result.Interface()
|
||||||
|
context.RenderMethodResult(resultValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !context.ResponseWriter.Started && context.Output.Status == 0 {
|
||||||
|
context.Output.SetStatus(200)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// FindRouter Find Router info for URL
|
// FindRouter Find Router info for URL
|
||||||
func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *controllerInfo, isFind bool) {
|
func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) {
|
||||||
var urlPath = context.Input.URL()
|
var urlPath = context.Input.URL()
|
||||||
if !BConfig.RouterCaseSensitive {
|
if !BConfig.RouterCaseSensitive {
|
||||||
urlPath = strings.ToLower(urlPath)
|
urlPath = strings.ToLower(urlPath)
|
||||||
@ -888,7 +928,7 @@ func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo
|
|||||||
httpMethod := context.Input.Method()
|
httpMethod := context.Input.Method()
|
||||||
if t, ok := p.routers[httpMethod]; ok {
|
if t, ok := p.routers[httpMethod]; ok {
|
||||||
runObject := t.Match(urlPath, context)
|
runObject := t.Match(urlPath, context)
|
||||||
if r, ok := runObject.(*controllerInfo); ok {
|
if r, ok := runObject.(*ControllerInfo); ok {
|
||||||
return r, true
|
return r, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user