", result)
+ }
+ }
+
+ func() {
+ ctrl.TplName = "file2.tpl"
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("TestAdditionalViewPaths expected error")
+ }
+ }()
+ ctrl.RenderString()
+ }()
+
+ ctrl.TplName = "file2.tpl"
+ ctrl.ViewPath = dir2
+ ctrl.RenderString()
+}
diff --git a/pkg/doc.go b/pkg/doc.go
new file mode 100644
index 00000000..8825bd29
--- /dev/null
+++ b/pkg/doc.go
@@ -0,0 +1,17 @@
+/*
+Package beego provides an MVC framework
+beego: an open-source, high-performance, modular, full-stack web framework
+
+It is used for rapid development of RESTful APIs, web apps and backend services in Go.
+beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding.
+
+ package main
+ import "github.com/astaxie/beego"
+
+ func main() {
+ beego.Run()
+ }
+
+more information: http://beego.me
+*/
+package beego
diff --git a/pkg/error.go b/pkg/error.go
new file mode 100644
index 00000000..f268f723
--- /dev/null
+++ b/pkg/error.go
@@ -0,0 +1,488 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "fmt"
+ "html/template"
+ "net/http"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+
+ "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/utils"
+)
+
+// Dispatch tags for registered error pages: an entry in ErrorMaps is either a
+// plain http.HandlerFunc (errorTypeHandler) or a ControllerInterface method
+// (errorTypeController); executeError branches on this tag.
+const (
+ errorTypeHandler = iota
+ errorTypeController
+)
+
+// errtpl is the HTML page template executed by responseError for the built-in
+// 40x/50x error pages. It is consumed as "errtpl" below, so the variable is
+// named accordingly (it was declared "tpl", which left errtpl undefined).
+// NOTE(review): the HTML markup appears to have been stripped from this raw
+// string literal — restore it from the upstream template before merging.
+var errtpl = `
+
+
+
+
+ beego application error
+
+
+
+
+
+ {{.Content}}
+ Go Home
+
+ Powered by beego {{.BeegoVersion}}
+
+
+
+
+
+`
+
+// errorInfo records how one HTTP error code is served: either a plain
+// http.HandlerFunc or a controller method resolved by reflection
+// (see errorTypeHandler / errorTypeController).
+type errorInfo struct {
+ controllerType reflect.Type
+ handler http.HandlerFunc
+ method string
+ errorType int
+}
+
+// ErrorMaps maps an HTTP status code string (e.g. "404") to its registered
+// handler. Ten default handlers (40x and 50x) are registered out of the box.
+var ErrorMaps = make(map[string]*errorInfo, 10)
+
+// ---------------------------------------------------------------------------
+// Default error-page handlers. Each writes a canned HTML body for one HTTP
+// status code via responseError; they are registered into ErrorMaps elsewhere.
+// NOTE(review): the HTML tags inside these string literals appear to have
+// been stripped (several interpreted string literals are even split across
+// physical lines, which is not valid Go) — restore the markup from the
+// upstream source before applying this patch.
+// ---------------------------------------------------------------------------
+
+// show 401 unauthorized error.
+func unauthorized(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 401,
+ " The page you have requested can't be authorized."+
+ " Perhaps you are here because:"+
+ "
"+
+ " The credentials you supplied are incorrect"+
+ " There are errors in the website address"+
+ "
",
+ )
+}
+
+// show 402 Payment Required
+func paymentRequired(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 402,
+ " The page you have requested Payment Required."+
+ " Perhaps you are here because:"+
+ "
"+
+ " The credentials you supplied are incorrect"+
+ " There are errors in the website address"+
+ "
",
+ )
+}
+
+// show 403 forbidden error.
+func forbidden(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 403,
+ " The page you have requested is forbidden."+
+ " Perhaps you are here because:"+
+ "
"+
+ " Your address may be blocked"+
+ " The site may be disabled"+
+ " You need to log in"+
+ "
",
+ )
+}
+
+// show 422 missing xsrf token
+func missingxsrf(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 422,
+ " The page you have requested is forbidden."+
+ " Perhaps you are here because:"+
+ "
"+
+ " '_xsrf' argument missing from POST"+
+ "
",
+ )
+}
+
+// show 417 invalid xsrf token
+func invalidxsrf(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 417,
+ " The page you have requested is forbidden."+
+ " Perhaps you are here because:"+
+ "
"+
+ " expected XSRF not found"+
+ "
",
+ )
+}
+
+// show 404 not found error.
+func notFound(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 404,
+ " The page you have requested has flown the coop."+
+ " Perhaps you are here because:"+
+ "
"+
+ " The page has moved"+
+ " The page no longer exists"+
+ " You were looking for your puppy and got lost"+
+ " You like 404 pages"+
+ "
",
+ )
+}
+
+// show 405 Method Not Allowed
+func methodNotAllowed(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 405,
+ " The method you have requested Not Allowed."+
+ " Perhaps you are here because:"+
+ "
"+
+ " The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+
+ " The response MUST include an Allow header containing a list of valid methods for the requested resource."+
+ "
",
+ )
+}
+
+// show 500 internal server error.
+func internalServerError(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 500,
+ " The page you have requested is down right now."+
+ "
"+
+ " Please try again later and report the error to the website administrator"+
+ "
",
+ )
+}
+
+// show 501 Not Implemented.
+func notImplemented(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 501,
+ " The page you have requested is Not Implemented."+
+ "
"+
+ " Please try again later and report the error to the website administrator"+
+ "
",
+ )
+}
+
+// show 502 Bad Gateway.
+func badGateway(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 502,
+ " The page you have requested is down right now."+
+ "
"+
+ " The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+
+ " Please try again later and report the error to the website administrator"+
+ "
",
+ )
+}
+
+// show 503 service unavailable error.
+func serviceUnavailable(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 503,
+ " The page you have requested is unavailable."+
+ " Perhaps you are here because:"+
+ "
"+
+ "
The page is overloaded"+
+ " Please try again later."+
+ "
",
+ )
+}
+
+// show 504 Gateway Timeout.
+func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 504,
+ " The page you have requested is unavailable"+
+ " Perhaps you are here because:"+
+ "
"+
+ "
The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+
+ " Please try again later."+
+ "
",
+ )
+}
+
+// show 413 Payload Too Large
+func payloadTooLarge(rw http.ResponseWriter, r *http.Request) {
+ responseError(rw, r,
+ 413,
+ ` The page you have requested is unavailable.
+ Perhaps you are here because:
+
+ The request entity is larger than limits defined by server.
+ Please change the request entity and try again.
+
+ `,
+ )
+}
+
+// responseError renders the shared error-page template with the status text
+// for errCode, the beego version, and the supplied HTML fragment. Template
+// parse/execute errors are intentionally discarded: emitting a partial error
+// page beats failing while reporting another failure. The caller is
+// responsible for writing the status-code header before invoking this.
+// NOTE(review): requires a package-level errtpl template — verify the
+// template variable's declared name matches this use.
+func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) {
+ t, _ := template.New("beegoerrortemp").Parse(errtpl)
+ data := M{
+ "Title": http.StatusText(errCode),
+ "BeegoVersion": VERSION,
+ "Content": template.HTML(errContent),
+ }
+ t.Execute(rw, data)
+}
+
+// ErrorHandler registers an http.HandlerFunc for the given HTTP error code
+// string, replacing any handler previously registered for that code.
+// It returns BeeApp so calls can be chained.
+// usage:
+// beego.ErrorHandler("404",NotFound)
+// beego.ErrorHandler("500",InternalServerError)
+func ErrorHandler(code string, h http.HandlerFunc) *App {
+ ErrorMaps[code] = &errorInfo{
+ errorType: errorTypeHandler,
+ handler: h,
+ method: code,
+ }
+ return BeeApp
+}
+
+// ErrorController scans the methods of c and registers every method whose
+// name starts with "Error" (e.g. Error404) as the handler for the
+// corresponding code string ("404"). Methods listed in exceptMethod are
+// skipped. It returns BeeApp so calls can be chained.
+// usage:
+// beego.ErrorController(&controllers.ErrorController{})
+func ErrorController(c ControllerInterface) *App {
+ reflectVal := reflect.ValueOf(c)
+ rt := reflectVal.Type()
+ // ct is the concrete (non-pointer) type; executeError instantiates it
+ // with reflect.New for each request.
+ ct := reflect.Indirect(reflectVal).Type()
+ for i := 0; i < rt.NumMethod(); i++ {
+ methodName := rt.Method(i).Name
+ if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") {
+ errName := strings.TrimPrefix(methodName, "Error")
+ ErrorMaps[errName] = &errorInfo{
+ errorType: errorTypeController,
+ controllerType: ct,
+ method: methodName,
+ }
+ }
+ }
+ return BeeApp
+}
+
+// Exception writes the HTTP status for errCode and executes the registered
+// error handler if one exists (see exception for the fallback rules).
+func Exception(errCode uint64, ctx *context.Context) {
+ exception(strconv.FormatUint(errCode, 10), ctx)
+}
+
+// exception dispatches to the handler registered for errCode, falling back
+// to the "503" then "500" handlers. If none is registered, the code is
+// written as a plain-text response instead.
+func exception(errCode string, ctx *context.Context) {
+ // atoi converts a code string to an int status; on a non-numeric code it
+ // falls back to the status already set on the context, or 503 if unset.
+ atoi := func(code string) int {
+ v, err := strconv.Atoi(code)
+ if err == nil {
+ return v
+ }
+ if ctx.Output.Status == 0 {
+ return 503
+ }
+ return ctx.Output.Status
+ }
+
+ for _, ec := range []string{errCode, "503", "500"} {
+ if h, ok := ErrorMaps[ec]; ok {
+ executeError(h, ctx, atoi(ec))
+ return
+ }
+ }
+ //if 50x error has been removed from errorMap
+ ctx.ResponseWriter.WriteHeader(atoi(errCode))
+ ctx.WriteString(errCode)
+}
+
+// executeError runs the registered errorInfo for the given status code:
+// either invoking its http.HandlerFunc directly, or instantiating its
+// controller via reflection and driving the full controller lifecycle
+// (Init, Prepare, URLMapping, the Error method, optional Render, Finish).
+// Note the parameter err is an *errorInfo, not a Go error.
+func executeError(err *errorInfo, ctx *context.Context, code int) {
+ //make sure to log the error in the access log
+ LogAccess(ctx, nil, code)
+
+ if err.errorType == errorTypeHandler {
+ ctx.ResponseWriter.WriteHeader(code)
+ err.handler(ctx.ResponseWriter, ctx.Request)
+ return
+ }
+ if err.errorType == errorTypeController {
+ ctx.Output.SetStatus(code)
+ //Invoke the request handler
+ vc := reflect.New(err.controllerType)
+ execController, ok := vc.Interface().(ControllerInterface)
+ if !ok {
+ // programmer error: a non-controller type was registered
+ panic("controller is not ControllerInterface")
+ }
+ //call the controller init function
+ execController.Init(ctx, err.controllerType.Name(), err.method, vc.Interface())
+
+ //call prepare function
+ execController.Prepare()
+
+ execController.URLMapping()
+
+ // invoke the ErrorXxx method itself (no arguments, no return values)
+ method := vc.MethodByName(err.method)
+ method.Call([]reflect.Value{})
+
+ //render template
+ if BConfig.WebConfig.AutoRender {
+ if err := execController.Render(); err != nil {
+ panic(err)
+ }
+ }
+
+ // finish all runrouter. release resource
+ execController.Finish()
+ }
+}
diff --git a/pkg/error_test.go b/pkg/error_test.go
new file mode 100644
index 00000000..378aa953
--- /dev/null
+++ b/pkg/error_test.go
@@ -0,0 +1,88 @@
+// Copyright 2016 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+// errorTestController aborts with the status code supplied in the "code"
+// query parameter, exercising the registered error handlers.
+type errorTestController struct {
+ Controller
+}
+
+// parseCodeError is the body emitted when "code" is not an integer.
+const parseCodeError = "parse code error"
+
+// Get aborts with the requested code: non-numeric -> parseCodeError,
+// non-zero -> CustomAbort with that code, zero -> 404.
+func (ec *errorTestController) Get() {
+ errorCode, err := ec.GetInt("code")
+ if err != nil {
+ ec.Abort(parseCodeError)
+ }
+ if errorCode != 0 {
+ ec.CustomAbort(errorCode, ec.GetString("code"))
+ }
+ ec.Abort("404")
+}
+
+// TestErrorCode_01 verifies every registered error code produces its own
+// status code and status text in the response body.
+func TestErrorCode_01(t *testing.T) {
+ registerDefaultErrorHandler()
+ for k := range ErrorMaps {
+ r, _ := http.NewRequest("GET", "/error?code="+k, nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/error", &errorTestController{})
+ handler.ServeHTTP(w, r)
+ code, _ := strconv.Atoi(k)
+ if w.Code != code {
+ t.Fail()
+ }
+ if !strings.Contains(w.Body.String(), http.StatusText(code)) {
+ t.Fail()
+ }
+ }
+}
+
+// TestErrorCode_02 verifies code=0 falls through to the 404 abort.
+func TestErrorCode_02(t *testing.T) {
+ registerDefaultErrorHandler()
+ r, _ := http.NewRequest("GET", "/error?code=0", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/error", &errorTestController{})
+ handler.ServeHTTP(w, r)
+ if w.Code != 404 {
+ t.Fail()
+ }
+}
+
+// TestErrorCode_03 verifies a non-numeric code aborts with the plain-text
+// parse error and a 200 status.
+func TestErrorCode_03(t *testing.T) {
+ registerDefaultErrorHandler()
+ r, _ := http.NewRequest("GET", "/error?code=panic", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/error", &errorTestController{})
+ handler.ServeHTTP(w, r)
+ if w.Code != 200 {
+ t.Fail()
+ }
+ if w.Body.String() != parseCodeError {
+ t.Fail()
+ }
+}
diff --git a/pkg/filter.go b/pkg/filter.go
new file mode 100644
index 00000000..9cc6e913
--- /dev/null
+++ b/pkg/filter.go
@@ -0,0 +1,44 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import "github.com/astaxie/beego/context"
+
+// FilterFunc defines a filter function which is invoked before the controller handler is executed.
+type FilterFunc func(*context.Context)
+
+// FilterRouter defines a filter operation which is invoked before the controller handler is executed.
+// It can match the URL against a pattern, and execute a filter function
+// when a request with a matching URL arrives.
+type FilterRouter struct {
+ filterFunc FilterFunc
+ tree *Tree
+ pattern string
+ returnOnOutput bool
+ resetParams bool
+}
+
+// ValidRouter reports whether url is matched by this filter's pattern tree.
+// Any URL parameters captured by the pattern are stored on ctx by tree.Match
+// as a side effect (they are not returned).
+func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool {
+ isOk := f.tree.Match(url, ctx)
+ if isOk != nil {
+ // a non-bool match result is treated as no match
+ if b, ok := isOk.(bool); ok {
+ return b
+ }
+ }
+ return false
+}
diff --git a/pkg/filter_test.go b/pkg/filter_test.go
new file mode 100644
index 00000000..4ca4d2b8
--- /dev/null
+++ b/pkg/filter_test.go
@@ -0,0 +1,68 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/astaxie/beego/context"
+)
+
+// FilterUser echoes the :last and :first URL parameters, proving the filter
+// ran and saw the router's captured parameters.
+var FilterUser = func(ctx *context.Context) {
+ ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first")))
+}
+
+// TestFilter verifies a BeforeRouter filter runs and can read URL params.
+func TestFilter(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/person/asta/Xie", nil)
+ w := httptest.NewRecorder()
+ handler := NewControllerRegister()
+ handler.InsertFilter("/person/:last/:first", BeforeRouter, FilterUser)
+ handler.Add("/person/:last/:first", &TestController{})
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "i am astaXie" {
+ t.Errorf("user define func can't run")
+ }
+}
+
+// FilterAdminUser writes a fixed body so the tests can detect it fired.
+var FilterAdminUser = func(ctx *context.Context) {
+ ctx.Output.Body([]byte("i am admin"))
+}
+
+// Filter pattern /admin/:all
+// all url like /admin/ /admin/xie will all get filter
+
+// TestPatternTwo verifies the optional-segment pattern matches /admin/.
+func TestPatternTwo(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/admin/", nil)
+ w := httptest.NewRecorder()
+ handler := NewControllerRegister()
+ handler.InsertFilter("/admin/?:all", BeforeRouter, FilterAdminUser)
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "i am admin" {
+ t.Errorf("filter /admin/ can't run")
+ }
+}
+
+// TestPatternThree verifies the wildcard pattern matches a sub-path.
+func TestPatternThree(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/admin/astaxie", nil)
+ w := httptest.NewRecorder()
+ handler := NewControllerRegister()
+ handler.InsertFilter("/admin/:all", BeforeRouter, FilterAdminUser)
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "i am admin" {
+ t.Errorf("filter /admin/astaxie can't run")
+ }
+}
diff --git a/pkg/flash.go b/pkg/flash.go
new file mode 100644
index 00000000..a6485a17
--- /dev/null
+++ b/pkg/flash.go
@@ -0,0 +1,110 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "fmt"
+ "net/url"
+ "strings"
+)
+
+// FlashData is a tool for carrying key/value messages across a single
+// redirect: values are written to a cookie by Store and consumed (then
+// deleted) by ReadFromRequest.
+type FlashData struct {
+ Data map[string]string
+}
+
+// NewFlash returns a new empty FlashData struct.
+func NewFlash() *FlashData {
+ return &FlashData{
+ Data: make(map[string]string),
+ }
+}
+
+// Set stores msg under key; extra args are applied with fmt.Sprintf.
+func (fd *FlashData) Set(key string, msg string, args ...interface{}) {
+ if len(args) == 0 {
+ fd.Data[key] = msg
+ } else {
+ fd.Data[key] = fmt.Sprintf(msg, args...)
+ }
+}
+
+// Success writes success message to flash.
+func (fd *FlashData) Success(msg string, args ...interface{}) {
+ if len(args) == 0 {
+ fd.Data["success"] = msg
+ } else {
+ fd.Data["success"] = fmt.Sprintf(msg, args...)
+ }
+}
+
+// Notice writes notice message to flash.
+func (fd *FlashData) Notice(msg string, args ...interface{}) {
+ if len(args) == 0 {
+ fd.Data["notice"] = msg
+ } else {
+ fd.Data["notice"] = fmt.Sprintf(msg, args...)
+ }
+}
+
+// Warning writes warning message to flash.
+func (fd *FlashData) Warning(msg string, args ...interface{}) {
+ if len(args) == 0 {
+ fd.Data["warning"] = msg
+ } else {
+ fd.Data["warning"] = fmt.Sprintf(msg, args...)
+ }
+}
+
+// Error writes error message to flash.
+func (fd *FlashData) Error(msg string, args ...interface{}) {
+ if len(args) == 0 {
+ fd.Data["error"] = msg
+ } else {
+ fd.Data["error"] = fmt.Sprintf(msg, args...)
+ }
+}
+
+// Store does the saving operation of flash data.
+// Entries are serialized as \x00 key \x23<sep>\x23 value \x00 pairs,
+// query-escaped, and written to the configured flash cookie.
+func (fd *FlashData) Store(c *Controller) {
+ c.Data["flash"] = fd.Data
+ var flashValue string
+ for key, value := range fd.Data {
+ flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00"
+ }
+ c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/")
+}
+
+// ReadFromRequest parses flash data from the encoded cookie written by
+// Store, deletes the cookie (flash data is read-once), and exposes the
+// values on c.Data["flash"].
+func ReadFromRequest(c *Controller) *FlashData {
+ flash := NewFlash()
+ if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil {
+ v, _ := url.QueryUnescape(cookie.Value)
+ vals := strings.Split(v, "\x00")
+ for _, v := range vals {
+ if len(v) > 0 {
+ kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23")
+ if len(kv) == 2 {
+ flash.Data[kv[0]] = kv[1]
+ }
+ }
+ }
+ //read one time then delete it
+ c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/")
+ }
+ c.Data["flash"] = flash.Data
+ return flash
+}
diff --git a/pkg/flash_test.go b/pkg/flash_test.go
new file mode 100644
index 00000000..d5e9608d
--- /dev/null
+++ b/pkg/flash_test.go
@@ -0,0 +1,54 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+// TestFlashController writes a notice flash message so the test can inspect
+// the resulting Set-Cookie header.
+type TestFlashController struct {
+ Controller
+}
+
+// TestWriteFlash stores one notice entry and serves JSON to finish the
+// request without needing a template.
+// NOTE(review): the receiver name t shadows the usual testing.T convention
+// in this file — consider renaming.
+func (t *TestFlashController) TestWriteFlash() {
+ flash := NewFlash()
+ flash.Notice("TestFlashString")
+ flash.Store(&t.Controller)
+ // we choose to serve json because we don't want to load a template html file
+ t.ServeJSON(true)
+}
+
+// TestFlashHeader verifies Store emits the expected escaped cookie value.
+func TestFlashHeader(t *testing.T) {
+ // create fake GET request
+ r, _ := http.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+
+ // setup the handler
+ handler := NewControllerRegister()
+ handler.Add("/", &TestFlashController{}, "get:TestWriteFlash")
+ handler.ServeHTTP(w, r)
+
+ // get the Set-Cookie value
+ sc := w.Header().Get("Set-Cookie")
+ // match for the expected header
+ res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
+ // validate the assertion
+ if !res {
+ t.Errorf("TestFlashHeader() unable to validate flash message")
+ }
+}
diff --git a/pkg/fs.go b/pkg/fs.go
new file mode 100644
index 00000000..41cc6f6e
--- /dev/null
+++ b/pkg/fs.go
@@ -0,0 +1,74 @@
+package beego
+
+import (
+ "net/http"
+ "os"
+ "path/filepath"
+)
+
+// FileSystem adapts the local OS filesystem to http.FileSystem.
+type FileSystem struct {
+}
+
+// Open opens name directly via os.Open.
+// NOTE(review): unlike http.Dir there is no root confinement here — callers
+// must not pass untrusted paths; confirm intended usage.
+func (d FileSystem) Open(name string) (http.File, error) {
+ return os.Open(name)
+}
+
+// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or
+// directory in the tree, including root. All errors that arise visiting files
+// and directories are filtered by walkFn.
+func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error {
+
+ f, err := fs.Open(root)
+ if err != nil {
+ return err
+ }
+ info, err := f.Stat()
+ // The root handle was previously never released (descriptor leak); it is
+ // only needed for Stat, so close it before descending.
+ f.Close()
+ if err != nil {
+ err = walkFn(root, nil, err)
+ } else {
+ err = walk(fs, root, info, walkFn)
+ }
+ // a SkipDir returned for the root means "nothing to do", not an error
+ if err == filepath.SkipDir {
+ return nil
+ }
+ return err
+}
+
+// walk recursively descends path, calling walkFn for path itself and then
+// for each entry of the directory (files via walkFn directly, directories
+// via recursion). filepath.SkipDir from a directory's walkFn prunes that
+// directory.
+func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error {
+ var err error
+ if !info.IsDir() {
+ return walkFn(path, info, nil)
+ }
+
+ dir, err := fs.Open(path)
+ if err != nil {
+ if err1 := walkFn(path, info, err); err1 != nil {
+ return err1
+ }
+ return err
+ }
+ defer dir.Close()
+ dirs, err := dir.Readdir(-1)
+ err1 := walkFn(path, info, err)
+ // If err != nil, walk can't walk into this directory.
+ // err1 != nil means walkFn want walk to skip this directory or stop walking.
+ // Therefore, if one of err and err1 isn't nil, walk will return.
+ if err != nil || err1 != nil {
+ // The caller's behavior is controlled by the return value, which is decided
+ // by walkFn. walkFn may ignore err and return nil.
+ // If walkFn returns SkipDir, it will be handled by the caller.
+ // So walk should return whatever walkFn returns.
+ return err1
+ }
+
+ for _, fileInfo := range dirs {
+ filename := filepath.Join(path, fileInfo.Name())
+ if err = walk(fs, filename, fileInfo, walkFn); err != nil {
+ // SkipDir from a subdirectory prunes it; anything else propagates
+ if !fileInfo.IsDir() || err != filepath.SkipDir {
+ return err
+ }
+ }
+ }
+ return nil
+}
diff --git a/pkg/grace/grace.go b/pkg/grace/grace.go
new file mode 100644
index 00000000..fb0cb7bb
--- /dev/null
+++ b/pkg/grace/grace.go
@@ -0,0 +1,166 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package grace use to hot reload
+// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/
+//
+// Usage:
+//
+// import(
+// "log"
+// "net/http"
+// "os"
+//
+// "github.com/astaxie/beego/grace"
+// )
+//
+// func handler(w http.ResponseWriter, r *http.Request) {
+// w.Write([]byte("WORLD!"))
+// }
+//
+// func main() {
+// mux := http.NewServeMux()
+// mux.HandleFunc("/hello", handler)
+//
+// err := grace.ListenAndServe("localhost:8080", mux)
+// if err != nil {
+// log.Println(err)
+// }
+// log.Println("Server on 8080 stopped")
+// os.Exit(0)
+// }
+package grace
+
+import (
+ "flag"
+ "net/http"
+ "os"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+)
+
+const (
+ // PreSignal is the position to add filter before signal
+ PreSignal = iota
+ // PostSignal is the position to add filter after signal
+ PostSignal
+ // StateInit represent the application inited
+ StateInit
+ // StateRunning represent the application is running
+ StateRunning
+ // StateShuttingDown represent the application is shutting down
+ StateShuttingDown
+ // StateTerminate represent the application is killed
+ StateTerminate
+)
+
+// Package-level registry of servers, guarded by regLock, so graceful restarts
+// can hand listener file descriptors to the forked child in a stable order.
+var (
+ regLock *sync.Mutex
+ runningServers map[string]*Server
+ runningServersOrder []string
+ socketPtrOffsetMap map[string]uint
+ runningServersForked bool
+
+ // DefaultReadTimeOut is the HTTP read timeout
+ DefaultReadTimeOut time.Duration
+ // DefaultWriteTimeOut is the HTTP Write timeout
+ DefaultWriteTimeOut time.Duration
+ // DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit
+ DefaultMaxHeaderBytes int
+ // DefaultTimeout is the shutdown server's timeout. default is 60s
+ DefaultTimeout = 60 * time.Second
+
+ isChild bool
+ socketOrder string
+
+ hookableSignals []os.Signal
+)
+
+// init wires the -graceful/-socketorder flags and the registry maps, and
+// declares which signals servers will hook (HUP for restart, INT/TERM for
+// shutdown).
+func init() {
+ flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
+ flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
+
+ regLock = &sync.Mutex{}
+ runningServers = make(map[string]*Server)
+ runningServersOrder = []string{}
+ socketPtrOffsetMap = make(map[string]uint)
+
+ hookableSignals = []os.Signal{
+ syscall.SIGHUP,
+ syscall.SIGINT,
+ syscall.SIGTERM,
+ }
+}
+
+// NewServer returns a new graceful Server for addr, registers it in the
+// package registry (under regLock), and records its socket offset so a
+// forked child can recover the right inherited file descriptor.
+func NewServer(addr string, handler http.Handler) (srv *Server) {
+ regLock.Lock()
+ defer regLock.Unlock()
+
+ if !flag.Parsed() {
+ flag.Parse()
+ }
+ // In a forked child, -socketorder tells us each address's fd offset;
+ // otherwise this server takes the next position. (The loop's addr
+ // deliberately shadows the parameter.)
+ if len(socketOrder) > 0 {
+ for i, addr := range strings.Split(socketOrder, ",") {
+ socketPtrOffsetMap[addr] = uint(i)
+ }
+ } else {
+ socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
+ }
+
+ srv = &Server{
+ // NOTE(review): sigChan is unbuffered; os/signal recommends a
+ // buffered channel for signal.Notify — confirm handleSignals keeps up.
+ sigChan: make(chan os.Signal),
+ isChild: isChild,
+ SignalHooks: map[int]map[os.Signal][]func(){
+ PreSignal: {
+ syscall.SIGHUP: {},
+ syscall.SIGINT: {},
+ syscall.SIGTERM: {},
+ },
+ PostSignal: {
+ syscall.SIGHUP: {},
+ syscall.SIGINT: {},
+ syscall.SIGTERM: {},
+ },
+ },
+ state: StateInit,
+ Network: "tcp",
+ terminalChan: make(chan error), //no cache channel
+ }
+ srv.Server = &http.Server{
+ Addr: addr,
+ ReadTimeout: DefaultReadTimeOut,
+ WriteTimeout: DefaultWriteTimeOut,
+ MaxHeaderBytes: DefaultMaxHeaderBytes,
+ Handler: handler,
+ }
+
+ runningServersOrder = append(runningServersOrder, addr)
+ runningServers[addr] = srv
+ return srv
+}
+
+// ListenAndServe refer http.ListenAndServe
+func ListenAndServe(addr string, handler http.Handler) error {
+ server := NewServer(addr, handler)
+ return server.ListenAndServe()
+}
+
+// ListenAndServeTLS refer http.ListenAndServeTLS
+func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
+ server := NewServer(addr, handler)
+ return server.ListenAndServeTLS(certFile, keyFile)
+}
diff --git a/pkg/grace/server.go b/pkg/grace/server.go
new file mode 100644
index 00000000..008a6171
--- /dev/null
+++ b/pkg/grace/server.go
@@ -0,0 +1,356 @@
+package grace
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "os/exec"
+ "os/signal"
+ "strings"
+ "syscall"
+ "time"
+)
+
+// Server embeds http.Server and adds graceful-restart support: it can hand
+// its listening socket to a forked child and drain in-flight requests on
+// shutdown.
+type Server struct {
+	*http.Server
+	ln           net.Listener                   // the (possibly inherited) listening socket
+	SignalHooks  map[int]map[os.Signal][]func() // user hooks, keyed by PreSignal/PostSignal then signal
+	sigChan      chan os.Signal                 // receives the hookable OS signals
+	isChild      bool                           // true when this process inherited its listener from a parent
+	state        uint8                          // StateInit -> StateRunning -> StateShuttingDown -> StateTerminate
+	Network      string                         // listener network, e.g. "tcp"
+	terminalChan chan error                     // carries Shutdown's result from shutdown() back to Serve()
+}
+
+// Serve accepts incoming connections on the Listener l,
+// creating a new service goroutine for each.
+// The service goroutines read requests and then call srv.Handler to reply to them.
+func (srv *Server) Serve() (err error) {
+	srv.state = StateRunning
+	defer func() { srv.state = StateTerminate }()
+
+	// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
+	// immediately return ErrServerClosed. Make sure the program doesn't exit
+	// and waits instead for Shutdown to return.
+	if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed {
+		log.Println(syscall.Getpid(), "Server.Serve() error:", err)
+		return err
+	}
+
+	log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
+	// wait for Shutdown to return; the matching send happens in shutdown(),
+	// so Serve only unblocks once the drain (or its timeout) has finished.
+	if shutdownErr := <-srv.terminalChan; shutdownErr != nil {
+		return shutdownErr
+	}
+	return
+}
+
+// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
+// to handle requests on incoming connections. If srv.Addr is blank, ":http" is
+// used.
+func (srv *Server) ListenAndServe() (err error) {
+	addr := srv.Addr
+	if addr == "" {
+		addr = ":http"
+	}
+
+	// start listening for restart/shutdown signals before serving
+	go srv.handleSignals()
+
+	srv.ln, err = srv.getListener(addr)
+	if err != nil {
+		log.Println(err)
+		return err
+	}
+
+	// A restarted child signals SIGTERM to its parent once it is ready,
+	// so the parent drains and exits without dropping connections.
+	if srv.isChild {
+		process, err := os.FindProcess(os.Getppid())
+		if err != nil {
+			log.Println(err)
+			return err
+		}
+		err = process.Signal(syscall.SIGTERM)
+		if err != nil {
+			return err
+		}
+	}
+
+	log.Println(os.Getpid(), srv.Addr)
+	return srv.Serve()
+}
+
+// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
+// Serve to handle requests on incoming TLS connections.
+//
+// Filenames containing a certificate and matching private key for the server must
+// be provided. If the certificate is signed by a certificate authority, the
+// certFile should be the concatenation of the server's certificate followed by the
+// CA's certificate.
+//
+// If srv.Addr is blank, ":https" is used.
+func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
+	addr := srv.Addr
+	if addr == "" {
+		addr = ":https"
+	}
+
+	if srv.TLSConfig == nil {
+		srv.TLSConfig = &tls.Config{}
+	}
+	if srv.TLSConfig.NextProtos == nil {
+		srv.TLSConfig.NextProtos = []string{"http/1.1"}
+	}
+
+	srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
+	srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
+	if err != nil {
+		return
+	}
+
+	go srv.handleSignals()
+
+	ln, err := srv.getListener(addr)
+	if err != nil {
+		log.Println(err)
+		return err
+	}
+	// wrap the raw listener so accepted TLS connections get TCP keep-alives
+	srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
+
+	// A restarted child signals SIGTERM to its parent once it is ready,
+	// so the parent drains and exits without dropping connections.
+	if srv.isChild {
+		process, err := os.FindProcess(os.Getppid())
+		if err != nil {
+			log.Println(err)
+			return err
+		}
+		err = process.Signal(syscall.SIGTERM)
+		if err != nil {
+			return err
+		}
+	}
+
+	log.Println(os.Getpid(), srv.Addr)
+	return srv.Serve()
+}
+
+// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
+// Serve to handle requests on incoming mutual TLS connections.
+//
+// certFile/keyFile configure the server identity; trustFile is a PEM bundle of
+// CA certificates used to verify client certificates. If srv.Addr is blank,
+// ":https" is used.
+func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
+	addr := srv.Addr
+	if addr == "" {
+		addr = ":https"
+	}
+
+	if srv.TLSConfig == nil {
+		srv.TLSConfig = &tls.Config{}
+	}
+	if srv.TLSConfig.NextProtos == nil {
+		srv.TLSConfig.NextProtos = []string{"http/1.1"}
+	}
+
+	srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
+	srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
+	if err != nil {
+		return
+	}
+	srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
+	pool := x509.NewCertPool()
+	data, err := ioutil.ReadFile(trustFile)
+	if err != nil {
+		log.Println(err)
+		return err
+	}
+	// AppendCertsFromPEM reports whether any certificate was parsed; without
+	// this check an empty or malformed trust file would silently produce a
+	// server that can never verify any client.
+	if !pool.AppendCertsFromPEM(data) {
+		err = fmt.Errorf("no client CA certificates parsed from trust file %q", trustFile)
+		log.Println(err)
+		return err
+	}
+	srv.TLSConfig.ClientCAs = pool
+	log.Println("Mutual HTTPS")
+	go srv.handleSignals()
+
+	ln, err := srv.getListener(addr)
+	if err != nil {
+		log.Println(err)
+		return err
+	}
+	// wrap the raw listener so accepted TLS connections get TCP keep-alives
+	srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
+
+	// A restarted child signals SIGTERM to its parent once it is ready.
+	if srv.isChild {
+		process, err := os.FindProcess(os.Getppid())
+		if err != nil {
+			log.Println(err)
+			return err
+		}
+		err = process.Signal(syscall.SIGTERM)
+		if err != nil {
+			return err
+		}
+	}
+
+	log.Println(os.Getpid(), srv.Addr)
+	return srv.Serve()
+}
+
+// getListener either opens a new socket to listen on, or takes the acceptor socket
+// it got passed when restarted.
+func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
+	if srv.isChild {
+		var ptrOffset uint
+		if len(socketPtrOffsetMap) > 0 {
+			ptrOffset = socketPtrOffsetMap[laddr]
+			log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
+		}
+
+		// fds 0-2 are stdin/stdout/stderr; inherited listener fds passed via
+		// cmd.ExtraFiles in fork() start at 3, in socketPtrOffsetMap order.
+		f := os.NewFile(uintptr(3+ptrOffset), "")
+		l, err = net.FileListener(f)
+		if err != nil {
+			err = fmt.Errorf("net.FileListener error: %v", err)
+			return
+		}
+	} else {
+		l, err = net.Listen(srv.Network, laddr)
+		if err != nil {
+			err = fmt.Errorf("net.Listen error: %v", err)
+			return
+		}
+	}
+	return
+}
+
+// tcpKeepAliveListener wraps a *net.TCPListener so that accepted connections
+// have TCP keep-alive enabled, mirroring net/http's unexported equivalent.
+type tcpKeepAliveListener struct {
+	*net.TCPListener
+}
+
+// Accept returns the next connection with keep-alive probes enabled every
+// 3 minutes. Errors from the keep-alive setup calls are ignored (best effort).
+func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
+	tc, err := ln.AcceptTCP()
+	if err != nil {
+		return
+	}
+	tc.SetKeepAlive(true)
+	tc.SetKeepAlivePeriod(3 * time.Minute)
+	return tc, nil
+}
+
+// handleSignals listens for os Signals and calls any hooked in function that the
+// user had registered with the signal. SIGHUP forks a fresh process that
+// inherits the listeners; SIGINT/SIGTERM start a graceful shutdown.
+// Runs forever in its own goroutine.
+func (srv *Server) handleSignals() {
+	var sig os.Signal
+
+	signal.Notify(
+		srv.sigChan,
+		hookableSignals...,
+	)
+
+	pid := syscall.Getpid()
+	for {
+		sig = <-srv.sigChan
+		// user hooks run before and after the built-in behavior
+		srv.signalHooks(PreSignal, sig)
+		switch sig {
+		case syscall.SIGHUP:
+			log.Println(pid, "Received SIGHUP. forking.")
+			err := srv.fork()
+			if err != nil {
+				log.Println("Fork err:", err)
+			}
+		case syscall.SIGINT:
+			log.Println(pid, "Received SIGINT.")
+			srv.shutdown()
+		case syscall.SIGTERM:
+			log.Println(pid, "Received SIGTERM.")
+			srv.shutdown()
+		default:
+			log.Printf("Received %v: nothing i care about...\n", sig)
+		}
+		srv.signalHooks(PostSignal, sig)
+	}
+}
+
+// signalHooks runs every hook registered for sig in the given phase
+// (PreSignal or PostSignal); signals with no hook entry are ignored.
+func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
+	hooks, ok := srv.SignalHooks[ppFlag][sig]
+	if !ok {
+		return
+	}
+	for _, hook := range hooks {
+		hook()
+	}
+}
+
+// shutdown closes the listener so that no new connections are accepted. it also
+// starts a goroutine that will serverTimeout (stop all running requests) the server
+// after DefaultTimeout.
+func (srv *Server) shutdown() {
+	// idempotent: only a running server can start shutting down
+	if srv.state != StateRunning {
+		return
+	}
+
+	srv.state = StateShuttingDown
+	log.Println(syscall.Getpid(), "Waiting for connections to finish...")
+	ctx := context.Background()
+	// DefaultTimeout >= 0 bounds the drain; a negative value waits forever
+	if DefaultTimeout >= 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
+		defer cancel()
+	}
+	// unblocks the receive in Serve(), handing it Shutdown's result;
+	// terminalChan is unbuffered, so this waits for Serve to be listening
+	srv.terminalChan <- srv.Server.Shutdown(ctx)
+}
+
+// fork re-execs the current binary as a child that inherits every running
+// server's listener fd (via ExtraFiles), passing -graceful and, when several
+// listeners exist, -socketorder so the child maps fds back to addresses.
+// Guarded so only the first SIGHUP actually forks.
+func (srv *Server) fork() (err error) {
+	regLock.Lock()
+	defer regLock.Unlock()
+	if runningServersForked {
+		return
+	}
+	runningServersForked = true
+
+	// order files by each listener's registered offset so the child's
+	// fd layout (starting at fd 3) matches socketPtrOffsetMap
+	var files = make([]*os.File, len(runningServers))
+	var orderArgs = make([]string, len(runningServers))
+	for _, srvPtr := range runningServers {
+		f, _ := srvPtr.ln.(*net.TCPListener).File()
+		files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f
+		orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
+	}
+
+	log.Println(files)
+	path := os.Args[0]
+	// copy our args, dropping any stale -graceful marker and what follows it
+	var args []string
+	if len(os.Args) > 1 {
+		for _, arg := range os.Args[1:] {
+			if arg == "-graceful" {
+				break
+			}
+			args = append(args, arg)
+		}
+	}
+	args = append(args, "-graceful")
+	if len(runningServers) > 1 {
+		args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
+		log.Println(args)
+	}
+	cmd := exec.Command(path, args...)
+	cmd.Stdout = os.Stdout
+	cmd.Stderr = os.Stderr
+	cmd.ExtraFiles = files
+	err = cmd.Start()
+	if err != nil {
+		// NOTE(review): Fatalf kills the (healthy) parent when the child fails
+		// to launch — verify this is intended rather than returning the error.
+		log.Fatalf("Restart: Failed to launch, error: %v", err)
+	}
+
+	return
+}
+
+// RegisterSignalHook registers a function to be run PreSignal or PostSignal
+// for a given signal. It returns an error when ppFlag is not one of
+// grace.PreSignal/grace.PostSignal, or when sig is not a hookable signal
+// (SIGHUP, SIGINT, SIGTERM).
+func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
+	if ppFlag != PreSignal && ppFlag != PostSignal {
+		// error strings are lowercase per Go convention (staticcheck ST1005)
+		err = fmt.Errorf("invalid ppFlag argument; must be either grace.PreSignal or grace.PostSignal")
+		return
+	}
+	for _, s := range hookableSignals {
+		if s == sig {
+			srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
+			return
+		}
+	}
+	err = fmt.Errorf("signal '%v' is not supported", sig)
+	return
+}
diff --git a/pkg/hooks.go b/pkg/hooks.go
new file mode 100644
index 00000000..49c42d5a
--- /dev/null
+++ b/pkg/hooks.go
@@ -0,0 +1,104 @@
+package beego
+
+import (
+ "encoding/json"
+ "mime"
+ "net/http"
+ "path/filepath"
+
+ "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
+ "github.com/astaxie/beego/session"
+)
+
+// registerMime registers every extension->content-type pair from the
+// package-level mimemaps table with the stdlib mime package.
+func registerMime() error {
+	for k, v := range mimemaps {
+		mime.AddExtensionType(k, v)
+	}
+	return nil
+}
+
+// registerDefaultErrorHandler installs the built-in HTTP error pages for the
+// status codes below, skipping any the application has already overridden
+// via ErrorHandler.
+func registerDefaultErrorHandler() error {
+	m := map[string]func(http.ResponseWriter, *http.Request){
+		"401": unauthorized,
+		"402": paymentRequired,
+		"403": forbidden,
+		"404": notFound,
+		"405": methodNotAllowed,
+		"500": internalServerError,
+		"501": notImplemented,
+		"502": badGateway,
+		"503": serviceUnavailable,
+		"504": gatewayTimeout,
+		"417": invalidxsrf,
+		"422": missingxsrf,
+		"413": payloadTooLarge,
+	}
+	for e, h := range m {
+		// user-registered handlers take precedence over the defaults
+		if _, ok := ErrorMaps[e]; !ok {
+			ErrorHandler(e, h)
+		}
+	}
+	return nil
+}
+
+// registerSession builds the global session manager when sessions are
+// enabled: config comes either from the individual BConfig session fields or,
+// when the app supplies a "sessionConfig" JSON blob, from that blob verbatim.
+// It also starts the session GC goroutine.
+func registerSession() error {
+	if BConfig.WebConfig.Session.SessionOn {
+		var err error
+		sessionConfig := AppConfig.String("sessionConfig")
+		conf := new(session.ManagerConfig)
+		if sessionConfig == "" {
+			conf.CookieName = BConfig.WebConfig.Session.SessionName
+			conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie
+			conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime
+			conf.Secure = BConfig.Listen.EnableHTTPS
+			conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime
+			conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig)
+			conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly
+			conf.Domain = BConfig.WebConfig.Session.SessionDomain
+			conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader
+			conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader
+			conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery
+		} else {
+			if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil {
+				return err
+			}
+		}
+		if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, conf); err != nil {
+			return err
+		}
+		go GlobalSessions.GC()
+	}
+	return nil
+}
+
+// registerTemplate loads the configured views directory into the template
+// registry and then locks the set of view paths against further changes.
+func registerTemplate() error {
+	defer lockViewPaths()
+	if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil {
+		// only surface the problem loudly in dev mode
+		if BConfig.RunMode == DEV {
+			logs.Warn(err)
+		}
+		return err
+	}
+	return nil
+}
+
+// registerAdmin starts the admin monitoring app in its own goroutine when
+// enabled in configuration.
+func registerAdmin() error {
+	if BConfig.Listen.EnableAdmin {
+		go beeAdminApp.Run()
+	}
+	return nil
+}
+
+// registerGzip initializes response gzip support from app configuration
+// when EnableGzip is set.
+func registerGzip() error {
+	if BConfig.EnableGzip {
+		context.InitGzip(
+			AppConfig.DefaultInt("gzipMinLength", -1),
+			AppConfig.DefaultInt("gzipCompressLevel", -1),
+			AppConfig.DefaultStrings("includedMethods", []string{"GET"}),
+		)
+	}
+	return nil
+}
diff --git a/pkg/httplib/README.md b/pkg/httplib/README.md
new file mode 100644
index 00000000..97df8e6b
--- /dev/null
+++ b/pkg/httplib/README.md
@@ -0,0 +1,97 @@
+# httplib
+httplib is a library that helps you send HTTP requests to remote URLs, much like curl.
+
+# How to use?
+
+## GET
+You can use Get to fetch data.
+
+ import "github.com/astaxie/beego/httplib"
+
+ str, err := httplib.Get("http://beego.me/").String()
+ if err != nil {
+ // error
+ }
+ fmt.Println(str)
+
+## POST
+POST data to remote url
+
+ req := httplib.Post("http://beego.me/")
+ req.Param("username","astaxie")
+ req.Param("password","123456")
+ str, err := req.String()
+ if err != nil {
+ // error
+ }
+ fmt.Println(str)
+
+## Set timeout
+
+The default timeout is `60` seconds, function prototype:
+
+ SetTimeout(connectTimeout, readWriteTimeout time.Duration)
+
+Example:
+
+ // GET
+ httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)
+
+ // POST
+ httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)
+
+
+## Debug
+
+If you want to debug the request info, set the debug on
+
+ httplib.Get("http://beego.me/").Debug(true)
+
+## Set HTTP Basic Auth
+
+ str, err := Get("http://beego.me/").SetBasicAuth("user", "passwd").String()
+ if err != nil {
+ // error
+ }
+ fmt.Println(str)
+
+## Set HTTPS
+
+If the request URL is https, you can configure the client to support TLS:
+
+ httplib.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
+
+More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config
+
+## Set HTTP Version
+
+some servers need to specify the protocol version of HTTP
+
+ httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1")
+
+## Set Cookie
+
+some http request need setcookie. So set it like this:
+
+ cookie := &http.Cookie{}
+ cookie.Name = "username"
+ cookie.Value = "astaxie"
+ httplib.Get("http://beego.me/").SetCookie(cookie)
+
+## Upload file
+
+httplib supports multiple file uploads; use `req.PostFile()`
+
+ req := httplib.Post("http://beego.me/")
+ req.Param("username","astaxie")
+ req.PostFile("uploadfile1", "httplib.pdf")
+ str, err := req.String()
+ if err != nil {
+ // error
+ }
+ fmt.Println(str)
+
+
+See godoc for further documentation and examples.
+
+* [godoc.org/github.com/astaxie/beego/httplib](https://godoc.org/github.com/astaxie/beego/httplib)
diff --git a/pkg/httplib/httplib.go b/pkg/httplib/httplib.go
new file mode 100644
index 00000000..60aa4e8b
--- /dev/null
+++ b/pkg/httplib/httplib.go
@@ -0,0 +1,654 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package httplib is used as http.Client
+// Usage:
+//
+// import "github.com/astaxie/beego/httplib"
+//
+// b := httplib.Post("http://beego.me/")
+// b.Param("username","astaxie")
+// b.Param("password","123456")
+// b.PostFile("uploadfile1", "httplib.pdf")
+// b.PostFile("uploadfile2", "httplib.txt")
+// str, err := b.String()
+// if err != nil {
+// t.Fatal(err)
+// }
+// fmt.Println(str)
+//
+// more docs http://beego.me/docs/module/httplib.md
+package httplib
+
+import (
+ "bytes"
+ "compress/gzip"
+ "crypto/tls"
+ "encoding/json"
+ "encoding/xml"
+ "io"
+ "io/ioutil"
+ "log"
+ "mime/multipart"
+ "net"
+ "net/http"
+ "net/http/cookiejar"
+ "net/http/httputil"
+ "net/url"
+ "os"
+ "path"
+ "strings"
+ "sync"
+ "time"
+
+ "gopkg.in/yaml.v2"
+)
+
+// defaultSetting is copied into every new request; override globally with
+// SetDefaultSetting or per request with (*BeegoHTTPRequest).Setting.
+var defaultSetting = BeegoHTTPSettings{
+	UserAgent:        "beegoServer",
+	ConnectTimeout:   60 * time.Second,
+	ReadWriteTimeout: 60 * time.Second,
+	Gzip:             true,
+	DumpBody:         true,
+}
+
+// defaultCookieJar is shared by all requests with EnableCookie set.
+var defaultCookieJar http.CookieJar
+// settingMutex guards defaultSetting and defaultCookieJar.
+var settingMutex sync.Mutex
+
+// createDefaultCookie creates a global cookiejar to store cookies.
+func createDefaultCookie() {
+	settingMutex.Lock()
+	defer settingMutex.Unlock()
+	defaultCookieJar, _ = cookiejar.New(nil)
+}
+
+// SetDefaultSetting Overwrite default settings
+func SetDefaultSetting(setting BeegoHTTPSettings) {
+	settingMutex.Lock()
+	defer settingMutex.Unlock()
+	defaultSetting = setting
+}
+
+// NewBeegoRequest return *BeegoHttpRequest with specific method.
+// NOTE(review): a URL parse failure is only logged; the request proceeds with
+// a nil URL and will fail later in DoRequest — verify callers expect that.
+func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest {
+	var resp http.Response
+	u, err := url.Parse(rawurl)
+	if err != nil {
+		log.Println("Httplib:", err)
+	}
+	req := http.Request{
+		URL:        u,
+		Method:     method,
+		Header:     make(http.Header),
+		Proto:      "HTTP/1.1",
+		ProtoMajor: 1,
+		ProtoMinor: 1,
+	}
+	return &BeegoHTTPRequest{
+		url:     rawurl,
+		req:     &req,
+		params:  map[string][]string{},
+		files:   map[string]string{},
+		setting: defaultSetting,
+		resp:    &resp,
+	}
+}
+
+// Get returns *BeegoHttpRequest with GET method.
+func Get(url string) *BeegoHTTPRequest {
+	return NewBeegoRequest(url, "GET")
+}
+
+// Post returns *BeegoHttpRequest with POST method.
+func Post(url string) *BeegoHTTPRequest {
+	return NewBeegoRequest(url, "POST")
+}
+
+// Put returns *BeegoHttpRequest with PUT method.
+func Put(url string) *BeegoHTTPRequest {
+	return NewBeegoRequest(url, "PUT")
+}
+
+// Delete returns *BeegoHttpRequest DELETE method.
+func Delete(url string) *BeegoHTTPRequest {
+	return NewBeegoRequest(url, "DELETE")
+}
+
+// Head returns *BeegoHttpRequest with HEAD method.
+func Head(url string) *BeegoHTTPRequest {
+	return NewBeegoRequest(url, "HEAD")
+}
+
+// BeegoHTTPSettings is the http.Client setting used when executing a request.
+type BeegoHTTPSettings struct {
+	ShowDebug        bool
+	UserAgent        string
+	ConnectTimeout   time.Duration
+	ReadWriteTimeout time.Duration
+	TLSClientConfig  *tls.Config
+	Proxy            func(*http.Request) (*url.URL, error)
+	Transport        http.RoundTripper
+	CheckRedirect    func(req *http.Request, via []*http.Request) error
+	EnableCookie     bool
+	Gzip             bool
+	DumpBody         bool
+	Retries          int // if set to -1 means will retry forever
+	RetryDelay       time.Duration
+}
+
+// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request.
+type BeegoHTTPRequest struct {
+	url     string              // the raw URL as given; query params are appended in buildURL
+	req     *http.Request       // the request being built
+	params  map[string][]string // query/form params, multi-valued
+	files   map[string]string   // formname -> filename for multipart uploads
+	setting BeegoHTTPSettings
+	resp    *http.Response // memoized response (StatusCode==0 means not yet executed)
+	body    []byte         // memoized response body (see Bytes)
+	dump    []byte         // request dump captured when ShowDebug is on
+}
+
+// GetRequest returns the underlying *http.Request being built.
+func (b *BeegoHTTPRequest) GetRequest() *http.Request {
+	return b.req
+}
+
+// Setting Change request settings
+func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest {
+	b.setting = setting
+	return b
+}
+
+// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password.
+func (b *BeegoHTTPRequest) SetBasicAuth(username, password string) *BeegoHTTPRequest {
+	b.req.SetBasicAuth(username, password)
+	return b
+}
+
+// SetEnableCookie sets enable/disable cookiejar
+func (b *BeegoHTTPRequest) SetEnableCookie(enable bool) *BeegoHTTPRequest {
+	b.setting.EnableCookie = enable
+	return b
+}
+
+// SetUserAgent sets User-Agent header field
+func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest {
+	b.setting.UserAgent = useragent
+	return b
+}
+
+// Debug sets show debug or not when executing request.
+func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
+	b.setting.ShowDebug = isdebug
+	return b
+}
+
+// Retries sets Retries times.
+// default is 0 means no retried.
+// -1 means retried forever.
+// others means retried times.
+func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest {
+	b.setting.Retries = times
+	return b
+}
+
+// RetryDelay sets the duration slept between retry attempts.
+func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest {
+	b.setting.RetryDelay = delay
+	return b
+}
+
+// DumpBody setting whether need to Dump the Body.
+func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
+	b.setting.DumpBody = isdump
+	return b
+}
+
+// DumpRequest return the DumpRequest
+func (b *BeegoHTTPRequest) DumpRequest() []byte {
+	return b.dump
+}
+
+// SetTimeout sets connect time out and read-write time out for BeegoRequest.
+func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest {
+	b.setting.ConnectTimeout = connectTimeout
+	b.setting.ReadWriteTimeout = readWriteTimeout
+	return b
+}
+
+// SetTLSClientConfig sets tls connection configurations if visiting https url.
+func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest {
+	b.setting.TLSClientConfig = config
+	return b
+}
+
+// Header add header item string in request.
+func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest {
+	b.req.Header.Set(key, value)
+	return b
+}
+
+// SetHost set the request host
+func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest {
+	b.req.Host = host
+	return b
+}
+
+// SetProtocolVersion sets the protocol version for the outgoing request.
+// An empty or unparsable version leaves the default of HTTP/1.1 in place.
+func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest {
+	if len(vers) == 0 {
+		vers = "HTTP/1.1"
+	}
+
+	major, minor, ok := http.ParseHTTPVersion(vers)
+	if ok {
+		b.req.Proto = vers
+		b.req.ProtoMajor = major
+		b.req.ProtoMinor = minor
+	}
+
+	return b
+}
+
+// SetCookie add cookie into request.
+func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest {
+	b.req.Header.Add("Cookie", cookie.String())
+	return b
+}
+
+// SetTransport set the setting transport
+func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest {
+	b.setting.Transport = transport
+	return b
+}
+
+// SetProxy set the http proxy
+// example:
+//
+//	func(req *http.Request) (*url.URL, error) {
+// 		u, _ := url.ParseRequestURI("http://127.0.0.1:8118")
+// 		return u, nil
+// 	}
+func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest {
+	b.setting.Proxy = proxy
+	return b
+}
+
+// SetCheckRedirect specifies the policy for handling redirects.
+//
+// If CheckRedirect is nil, the Client uses its default policy,
+// which is to stop after 10 consecutive requests.
+func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) *BeegoHTTPRequest {
+	b.setting.CheckRedirect = redirect
+	return b
+}
+
+// Param adds query param in to request.
+// params build query string as ?key1=value1&key2=value2...
+// Repeated keys accumulate multiple values.
+func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest {
+	if param, ok := b.params[key]; ok {
+		b.params[key] = append(param, value)
+	} else {
+		b.params[key] = []string{value}
+	}
+	return b
+}
+
+// PostFile add a post file to the request; the upload happens lazily when the
+// request is executed (see buildURL's multipart branch).
+func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest {
+	b.files[formname] = filename
+	return b
+}
+
+// Body adds request raw body.
+// it supports string and []byte. Other types are silently ignored.
+func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest {
+	switch t := data.(type) {
+	case string:
+		bf := bytes.NewBufferString(t)
+		b.req.Body = ioutil.NopCloser(bf)
+		b.req.ContentLength = int64(len(t))
+	case []byte:
+		bf := bytes.NewBuffer(t)
+		b.req.Body = ioutil.NopCloser(bf)
+		b.req.ContentLength = int64(len(t))
+	}
+	return b
+}
+
+// XMLBody adds request raw body encoding by XML.
+// It is a no-op when a body has already been set or obj is nil.
+func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
+	if b.req.Body == nil && obj != nil {
+		byts, err := xml.Marshal(obj)
+		if err != nil {
+			return b, err
+		}
+		b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
+		b.req.ContentLength = int64(len(byts))
+		b.req.Header.Set("Content-Type", "application/xml")
+	}
+	return b, nil
+}
+
+// YAMLBody adds request raw body encoding by YAML.
+// It is a no-op when a body has already been set or obj is nil.
+func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) {
+	if b.req.Body == nil && obj != nil {
+		byts, err := yaml.Marshal(obj)
+		if err != nil {
+			return b, err
+		}
+		b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
+		b.req.ContentLength = int64(len(byts))
+		b.req.Header.Set("Content-Type", "application/x+yaml")
+	}
+	return b, nil
+}
+
+// JSONBody adds request raw body encoding by JSON.
+// It is a no-op when a body has already been set or obj is nil.
+func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) {
+	if b.req.Body == nil && obj != nil {
+		byts, err := json.Marshal(obj)
+		if err != nil {
+			return b, err
+		}
+		b.req.Body = ioutil.NopCloser(bytes.NewReader(byts))
+		b.req.ContentLength = int64(len(byts))
+		b.req.Header.Set("Content-Type", "application/json")
+	}
+	return b, nil
+}
+
+// buildURL folds the accumulated params into the request: as a query string
+// for GET, as a streamed multipart body when files are attached, or as a
+// form-urlencoded body otherwise. An explicitly set body is left untouched.
+func (b *BeegoHTTPRequest) buildURL(paramBody string) {
+	// build GET url with query string
+	if b.req.Method == "GET" && len(paramBody) > 0 {
+		if strings.Contains(b.url, "?") {
+			b.url += "&" + paramBody
+		} else {
+			b.url = b.url + "?" + paramBody
+		}
+		return
+	}
+
+	// build POST/PUT/PATCH url and body
+	if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil {
+		// with files: stream the multipart body through a pipe so large
+		// files are never fully buffered in memory; the writer goroutine
+		// finishes when the transport drains the pipe reader
+		if len(b.files) > 0 {
+			pr, pw := io.Pipe()
+			bodyWriter := multipart.NewWriter(pw)
+			go func() {
+				for formname, filename := range b.files {
+					fileWriter, err := bodyWriter.CreateFormFile(formname, filename)
+					if err != nil {
+						log.Println("Httplib:", err)
+					}
+					fh, err := os.Open(filename)
+					if err != nil {
+						log.Println("Httplib:", err)
+					}
+					//iocopy
+					_, err = io.Copy(fileWriter, fh)
+					fh.Close()
+					if err != nil {
+						log.Println("Httplib:", err)
+					}
+				}
+				for k, v := range b.params {
+					for _, vv := range v {
+						bodyWriter.WriteField(k, vv)
+					}
+				}
+				bodyWriter.Close()
+				pw.Close()
+			}()
+			b.Header("Content-Type", bodyWriter.FormDataContentType())
+			b.req.Body = ioutil.NopCloser(pr)
+			b.Header("Transfer-Encoding", "chunked")
+			return
+		}
+
+		// with params
+		if len(paramBody) > 0 {
+			b.Header("Content-Type", "application/x-www-form-urlencoded")
+			b.Body(paramBody)
+		}
+	}
+}
+
+// getResponse executes the request once and memoizes the *http.Response so
+// repeated consumers (String, Bytes, ToFile, ...) reuse the same reply.
+func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
+	// a non-zero status code means DoRequest already ran for this request
+	if b.resp.StatusCode != 0 {
+		return b.resp, nil
+	}
+	resp, err := b.DoRequest()
+	if err != nil {
+		return nil, err
+	}
+	b.resp = resp
+	return resp, nil
+}
+
+// DoRequest will do the client.Do: it serializes params, finalizes the URL
+// and body, builds a transport/client from the settings, optionally dumps
+// the request, and executes it with the configured retry policy.
+func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) {
+	var paramBody string
+	if len(b.params) > 0 {
+		var buf bytes.Buffer
+		for k, v := range b.params {
+			for _, vv := range v {
+				buf.WriteString(url.QueryEscape(k))
+				buf.WriteByte('=')
+				buf.WriteString(url.QueryEscape(vv))
+				buf.WriteByte('&')
+			}
+		}
+		paramBody = buf.String()
+		// drop the trailing '&'
+		paramBody = paramBody[0 : len(paramBody)-1]
+	}
+
+	b.buildURL(paramBody)
+	urlParsed, err := url.Parse(b.url)
+	if err != nil {
+		return nil, err
+	}
+
+	b.req.URL = urlParsed
+
+	trans := b.setting.Transport
+
+	if trans == nil {
+		// create default transport
+		trans = &http.Transport{
+			TLSClientConfig:     b.setting.TLSClientConfig,
+			Proxy:               b.setting.Proxy,
+			Dial:                TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
+			MaxIdleConnsPerHost: 100,
+		}
+	} else {
+		// if b.transport is *http.Transport then set the settings.
+		if t, ok := trans.(*http.Transport); ok {
+			if t.TLSClientConfig == nil {
+				t.TLSClientConfig = b.setting.TLSClientConfig
+			}
+			if t.Proxy == nil {
+				t.Proxy = b.setting.Proxy
+			}
+			if t.Dial == nil {
+				t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout)
+			}
+		}
+	}
+
+	var jar http.CookieJar
+	if b.setting.EnableCookie {
+		// lazily create the process-wide cookie jar shared by all requests
+		if defaultCookieJar == nil {
+			createDefaultCookie()
+		}
+		jar = defaultCookieJar
+	}
+
+	client := &http.Client{
+		Transport: trans,
+		Jar:       jar,
+	}
+
+	if b.setting.UserAgent != "" && b.req.Header.Get("User-Agent") == "" {
+		b.req.Header.Set("User-Agent", b.setting.UserAgent)
+	}
+
+	if b.setting.CheckRedirect != nil {
+		client.CheckRedirect = b.setting.CheckRedirect
+	}
+
+	if b.setting.ShowDebug {
+		dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody)
+		if err != nil {
+			log.Println(err.Error())
+		}
+		b.dump = dump
+	}
+	// Retries defaults to 0 (run once); -1 retries forever until success;
+	// any positive value retries that fixed number of times, sleeping
+	// RetryDelay between attempts.
+	// NOTE(review): a consumed req.Body is not rewound between attempts, so
+	// retried POST/PUT requests may resend an empty body — verify before
+	// relying on retries for requests with bodies.
+	for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
+		resp, err = client.Do(b.req)
+		if err == nil {
+			break
+		}
+		time.Sleep(b.setting.RetryDelay)
+	}
+	return resp, err
+}
+
+// String returns the body string in response.
+// it calls Response inner; like Bytes, the result is memoized.
+func (b *BeegoHTTPRequest) String() (string, error) {
+	data, err := b.Bytes()
+	if err != nil {
+		return "", err
+	}
+
+	return string(data), nil
+}
+
+// Bytes returns the body []byte in response.
+// it calls Response inner. The body is read fully once and memoized, so
+// repeated calls (and String/ToJSON/...) reuse the same bytes.
+func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
+	if b.body != nil {
+		return b.body, nil
+	}
+	resp, err := b.getResponse()
+	if err != nil {
+		return nil, err
+	}
+	if resp.Body == nil {
+		return nil, nil
+	}
+	defer resp.Body.Close()
+	if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" {
+		reader, err := gzip.NewReader(resp.Body)
+		if err != nil {
+			return nil, err
+		}
+		// close the gzip reader too, so its resources are released and any
+		// trailing checksum error would not be silently leaked
+		defer reader.Close()
+		b.body, err = ioutil.ReadAll(reader)
+		return b.body, err
+	}
+	b.body, err = ioutil.ReadAll(resp.Body)
+	return b.body, err
+}
+
+// ToFile saves the body data in response to one file.
+// it calls Response inner, creating the destination directory when needed
+// and streaming the body straight to disk.
+func (b *BeegoHTTPRequest) ToFile(filename string) error {
+	resp, err := b.getResponse()
+	if err != nil {
+		return err
+	}
+	if resp.Body == nil {
+		return nil
+	}
+	defer resp.Body.Close()
+	err = pathExistAndMkdir(filename)
+	if err != nil {
+		return err
+	}
+	f, err := os.Create(filename)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+	_, err = io.Copy(f, resp.Body)
+	return err
+}
+
+// pathExistAndMkdir makes sure the directory portion of filename exists,
+// creating it (and any missing parents) when it does not.
+func pathExistAndMkdir(filename string) (err error) {
+	dir := path.Dir(filename)
+	if _, err = os.Stat(dir); err == nil {
+		return nil
+	}
+	if os.IsNotExist(err) {
+		return os.MkdirAll(dir, os.ModePerm)
+	}
+	return err
+}
+
+// ToJSON unmarshals the response body as JSON into v.
+// it calls Response inner.
+func (b *BeegoHTTPRequest) ToJSON(v interface{}) error {
+	data, err := b.Bytes()
+	if err != nil {
+		return err
+	}
+	return json.Unmarshal(data, v)
+}
+
+// ToXML unmarshals the response body as XML into v.
+// it calls Response inner.
+func (b *BeegoHTTPRequest) ToXML(v interface{}) error {
+	data, err := b.Bytes()
+	if err != nil {
+		return err
+	}
+	return xml.Unmarshal(data, v)
+}
+
+// ToYAML unmarshals the response body as YAML into v.
+// it calls Response inner.
+func (b *BeegoHTTPRequest) ToYAML(v interface{}) error {
+	data, err := b.Bytes()
+	if err != nil {
+		return err
+	}
+	return yaml.Unmarshal(data, v)
+}
+
+// Response executes the request and gets the response manually.
+func (b *BeegoHTTPRequest) Response() (*http.Response, error) {
+	return b.getResponse()
+}
+
+// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field.
+// Note the read-write deadline is set once, at connect time, for the whole
+// connection — it is not extended per read/write.
+func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
+	return func(netw, addr string) (net.Conn, error) {
+		conn, err := net.DialTimeout(netw, addr, cTimeout)
+		if err != nil {
+			return nil, err
+		}
+		err = conn.SetDeadline(time.Now().Add(rwTimeout))
+		return conn, err
+	}
+}
diff --git a/pkg/httplib/httplib_test.go b/pkg/httplib/httplib_test.go
new file mode 100644
index 00000000..f6be8571
--- /dev/null
+++ b/pkg/httplib/httplib_test.go
@@ -0,0 +1,286 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package httplib
+
+import (
+ "errors"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "os"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestResponse(t *testing.T) {
+ req := Get("http://httpbin.org/get")
+ resp, err := req.Response()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(resp)
+}
+
+func TestDoRequest(t *testing.T) {
+ req := Get("https://goolnk.com/33BD2j")
+ retryAmount := 1
+ req.Retries(1)
+ req.RetryDelay(1400 * time.Millisecond)
+ retryDelay := 1400 * time.Millisecond
+
+ req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error {
+ return errors.New("Redirect triggered")
+ }
+
+ startTime := time.Now().UnixNano() / int64(time.Millisecond)
+
+ _, err := req.Response()
+ if err == nil {
+ t.Fatal("Response should have yielded an error")
+ }
+
+ endTime := time.Now().UnixNano() / int64(time.Millisecond)
+ elapsedTime := endTime - startTime
+ delayedTime := int64(retryAmount) * retryDelay.Milliseconds()
+
+ if elapsedTime < delayedTime {
+ t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime)
+ }
+
+}
+
+func TestGet(t *testing.T) {
+ req := Get("http://httpbin.org/get")
+ b, err := req.Bytes()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(b)
+
+ s, err := req.String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(s)
+
+ if string(b) != s {
+ t.Fatal("request data not match")
+ }
+}
+
+func TestSimplePost(t *testing.T) {
+ v := "smallfish"
+ req := Post("http://httpbin.org/post")
+ req.Param("username", v)
+
+ str, err := req.String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+
+ n := strings.Index(str, v)
+ if n == -1 {
+ t.Fatal(v + " not found in post")
+ }
+}
+
+//func TestPostFile(t *testing.T) {
+// v := "smallfish"
+// req := Post("http://httpbin.org/post")
+// req.Debug(true)
+// req.Param("username", v)
+// req.PostFile("uploadfile", "httplib_test.go")
+
+// str, err := req.String()
+// if err != nil {
+// t.Fatal(err)
+// }
+// t.Log(str)
+
+// n := strings.Index(str, v)
+// if n == -1 {
+// t.Fatal(v + " not found in post")
+// }
+//}
+
+func TestSimplePut(t *testing.T) {
+ str, err := Put("http://httpbin.org/put").String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+}
+
+func TestSimpleDelete(t *testing.T) {
+ str, err := Delete("http://httpbin.org/delete").String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+}
+
+func TestSimpleDeleteParam(t *testing.T) {
+ str, err := Delete("http://httpbin.org/delete").Param("key", "val").String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+}
+
+func TestWithCookie(t *testing.T) {
+ v := "smallfish"
+ str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+
+ str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+
+ n := strings.Index(str, v)
+ if n == -1 {
+ t.Fatal(v + " not found in cookie")
+ }
+}
+
+func TestWithBasicAuth(t *testing.T) {
+ str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+ n := strings.Index(str, "authenticated")
+ if n == -1 {
+ t.Fatal("authenticated not found in response")
+ }
+}
+
+func TestWithUserAgent(t *testing.T) {
+ v := "beego"
+ str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+
+ n := strings.Index(str, v)
+ if n == -1 {
+ t.Fatal(v + " not found in user-agent")
+ }
+}
+
+func TestWithSetting(t *testing.T) {
+ v := "beego"
+ var setting BeegoHTTPSettings
+ setting.EnableCookie = true
+ setting.UserAgent = v
+ setting.Transport = &http.Transport{
+ DialContext: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ DualStack: true,
+ }).DialContext,
+ MaxIdleConns: 50,
+ IdleConnTimeout: 90 * time.Second,
+ ExpectContinueTimeout: 1 * time.Second,
+ }
+ setting.ReadWriteTimeout = 5 * time.Second
+ SetDefaultSetting(setting)
+
+ str, err := Get("http://httpbin.org/get").String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+
+ n := strings.Index(str, v)
+ if n == -1 {
+ t.Fatal(v + " not found in user-agent")
+ }
+}
+
+func TestToJson(t *testing.T) {
+ req := Get("http://httpbin.org/ip")
+ resp, err := req.Response()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(resp)
+
+ // httpbin will return http remote addr
+ type IP struct {
+ Origin string `json:"origin"`
+ }
+ var ip IP
+ err = req.ToJSON(&ip)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(ip.Origin)
+ ips := strings.Split(ip.Origin, ",")
+ if len(ips) == 0 {
+ t.Fatal("response is not valid ip")
+ }
+ for i := range ips {
+ if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil {
+ t.Fatal("response is not valid ip")
+ }
+ }
+
+}
+
+func TestToFile(t *testing.T) {
+ f := "beego_testfile"
+ req := Get("http://httpbin.org/ip")
+ err := req.ToFile(f)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(f)
+ b, err := ioutil.ReadFile(f)
+ if n := strings.Index(string(b), "origin"); n == -1 {
+ t.Fatal(err)
+ }
+}
+
+func TestToFileDir(t *testing.T) {
+ f := "./files/beego_testfile"
+ req := Get("http://httpbin.org/ip")
+ err := req.ToFile(f)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll("./files")
+ b, err := ioutil.ReadFile(f)
+ if n := strings.Index(string(b), "origin"); n == -1 {
+ t.Fatal(err)
+ }
+}
+
+func TestHeader(t *testing.T) {
+ req := Get("http://httpbin.org/headers")
+ req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36")
+ str, err := req.String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+}
diff --git a/pkg/log.go b/pkg/log.go
new file mode 100644
index 00000000..cc4c0f81
--- /dev/null
+++ b/pkg/log.go
@@ -0,0 +1,127 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "strings"
+
+ "github.com/astaxie/beego/logs"
+)
+
+// Log levels to control the logging output.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+const (
+ LevelEmergency = iota
+ LevelAlert
+ LevelCritical
+ LevelError
+ LevelWarning
+ LevelNotice
+ LevelInformational
+ LevelDebug
+)
+
+// BeeLogger references the used application logger.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+var BeeLogger = logs.GetBeeLogger()
+
+// SetLevel sets the global log level used by the simple logger.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func SetLevel(l int) {
+ logs.SetLevel(l)
+}
+
+// SetLogFuncCall set the CallDepth, default is 3
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func SetLogFuncCall(b bool) {
+ logs.SetLogFuncCall(b)
+}
+
+// SetLogger sets a new logger.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func SetLogger(adaptername string, config string) error {
+ return logs.SetLogger(adaptername, config)
+}
+
+// Emergency logs a message at emergency level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Emergency(v ...interface{}) {
+ logs.Emergency(generateFmtStr(len(v)), v...)
+}
+
+// Alert logs a message at alert level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Alert(v ...interface{}) {
+ logs.Alert(generateFmtStr(len(v)), v...)
+}
+
+// Critical logs a message at critical level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Critical(v ...interface{}) {
+ logs.Critical(generateFmtStr(len(v)), v...)
+}
+
+// Error logs a message at error level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Error(v ...interface{}) {
+ logs.Error(generateFmtStr(len(v)), v...)
+}
+
+// Warning logs a message at warning level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Warning(v ...interface{}) {
+ logs.Warning(generateFmtStr(len(v)), v...)
+}
+
+// Warn compatibility alias for Warning()
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Warn(v ...interface{}) {
+ logs.Warn(generateFmtStr(len(v)), v...)
+}
+
+// Notice logs a message at notice level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Notice(v ...interface{}) {
+ logs.Notice(generateFmtStr(len(v)), v...)
+}
+
+// Informational logs a message at info level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Informational(v ...interface{}) {
+ logs.Informational(generateFmtStr(len(v)), v...)
+}
+
+// Info compatibility alias for Informational()
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Info(v ...interface{}) {
+ logs.Info(generateFmtStr(len(v)), v...)
+}
+
+// Debug logs a message at debug level.
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Debug(v ...interface{}) {
+ logs.Debug(generateFmtStr(len(v)), v...)
+}
+
+// Trace logs a message at trace level.
+// compatibility alias for Warning()
+// Deprecated: use github.com/astaxie/beego/logs instead.
+func Trace(v ...interface{}) {
+ logs.Trace(generateFmtStr(len(v)), v...)
+}
+
+func generateFmtStr(n int) string {
+ return strings.Repeat("%v ", n)
+}
diff --git a/pkg/logs/README.md b/pkg/logs/README.md
new file mode 100644
index 00000000..c05bcc04
--- /dev/null
+++ b/pkg/logs/README.md
@@ -0,0 +1,72 @@
+## logs
+logs is a Go logs manager. It can use many logs adapters. The repo is inspired by `database/sql` .
+
+
+## How to install?
+
+ go get github.com/astaxie/beego/logs
+
+
+## What adapters are supported?
+
+As of now this logs package supports the console, file, smtp and conn adapters.
+
+
+## How to use it?
+
+First you must import it
+
+```golang
+import (
+ "github.com/astaxie/beego/logs"
+)
+```
+
+Then init a Log (example with console adapter)
+
+```golang
+log := logs.NewLogger(10000)
+log.SetLogger("console", "")
+```
+
+> the first parameter sets the buffer size of the log message channel
+
+Use it like this:
+
+```golang
+log.Trace("trace")
+log.Info("info")
+log.Warn("warning")
+log.Debug("debug")
+log.Critical("critical")
+```
+
+## File adapter
+
+Configure file adapter like this:
+
+```golang
+log := NewLogger(10000)
+log.SetLogger("file", `{"filename":"test.log"}`)
+```
+
+## Conn adapter
+
+Configure like this:
+
+```golang
+log := NewLogger(1000)
+log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`)
+log.Info("info")
+```
+
+## Smtp adapter
+
+Configure like this:
+
+```golang
+log := NewLogger(10000)
+log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`)
+log.Critical("sendmail critical")
+time.Sleep(time.Second * 30)
+```
diff --git a/pkg/logs/accesslog.go b/pkg/logs/accesslog.go
new file mode 100644
index 00000000..3ff9e20f
--- /dev/null
+++ b/pkg/logs/accesslog.go
@@ -0,0 +1,83 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "bytes"
+ "strings"
+ "encoding/json"
+ "fmt"
+ "time"
+)
+
+const (
+ apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s"
+ apacheFormat = "APACHE_FORMAT"
+ jsonFormat = "JSON_FORMAT"
+)
+
+// AccessLogRecord struct for holding access log data.
+type AccessLogRecord struct {
+ RemoteAddr string `json:"remote_addr"`
+ RequestTime time.Time `json:"request_time"`
+ RequestMethod string `json:"request_method"`
+ Request string `json:"request"`
+ ServerProtocol string `json:"server_protocol"`
+ Host string `json:"host"`
+ Status int `json:"status"`
+ BodyBytesSent int64 `json:"body_bytes_sent"`
+ ElapsedTime time.Duration `json:"elapsed_time"`
+ HTTPReferrer string `json:"http_referrer"`
+ HTTPUserAgent string `json:"http_user_agent"`
+ RemoteUser string `json:"remote_user"`
+}
+
+func (r *AccessLogRecord) json() ([]byte, error) {
+ buffer := &bytes.Buffer{}
+ encoder := json.NewEncoder(buffer)
+ disableEscapeHTML(encoder)
+
+ err := encoder.Encode(r)
+ return buffer.Bytes(), err
+}
+
+func disableEscapeHTML(i interface{}) {
+ if e, ok := i.(interface {
+ SetEscapeHTML(bool)
+ }); ok {
+ e.SetEscapeHTML(false)
+ }
+}
+
+// AccessLog - Format and print access log.
+func AccessLog(r *AccessLogRecord, format string) {
+ var msg string
+ switch format {
+ case apacheFormat:
+ timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05")
+ msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent,
+ r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent)
+ case jsonFormat:
+ fallthrough
+ default:
+ jsonData, err := r.json()
+ if err != nil {
+ msg = fmt.Sprintf(`{"Error": "%s"}`, err)
+ } else {
+ msg = string(jsonData)
+ }
+ }
+ beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg))
+}
diff --git a/pkg/logs/alils/alils.go b/pkg/logs/alils/alils.go
new file mode 100644
index 00000000..867ff4cb
--- /dev/null
+++ b/pkg/logs/alils/alils.go
@@ -0,0 +1,186 @@
+package alils
+
+import (
+ "encoding/json"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/astaxie/beego/logs"
+ "github.com/gogo/protobuf/proto"
+)
+
+const (
+ // CacheSize set the flush size
+ CacheSize int = 64
+ // Delimiter define the topic delimiter
+ Delimiter string = "##"
+)
+
+// Config is the Config for Ali Log
+type Config struct {
+ Project string `json:"project"`
+ Endpoint string `json:"endpoint"`
+ KeyID string `json:"key_id"`
+ KeySecret string `json:"key_secret"`
+ LogStore string `json:"log_store"`
+ Topics []string `json:"topics"`
+ Source string `json:"source"`
+ Level int `json:"level"`
+ FlushWhen int `json:"flush_when"`
+}
+
+// aliLSWriter implements LoggerInterface.
+// It buffers messages into log groups and ships them to the Aliyun Log Service.
+type aliLSWriter struct {
+ store *LogStore
+ group []*LogGroup
+ withMap bool
+ groupMap map[string]*LogGroup
+ lock *sync.Mutex
+ Config
+}
+
+// NewAliLS creates a new Logger backed by the Aliyun Log Service.
+func NewAliLS() logs.Logger {
+ alils := new(aliLSWriter)
+ alils.Level = logs.LevelTrace
+ return alils
+}
+
+// Init parse config and init struct
+func (c *aliLSWriter) Init(jsonConfig string) (err error) {
+
+ json.Unmarshal([]byte(jsonConfig), c)
+
+ if c.FlushWhen > CacheSize {
+ c.FlushWhen = CacheSize
+ }
+
+ prj := &LogProject{
+ Name: c.Project,
+ Endpoint: c.Endpoint,
+ AccessKeyID: c.KeyID,
+ AccessKeySecret: c.KeySecret,
+ }
+
+ c.store, err = prj.GetLogStore(c.LogStore)
+ if err != nil {
+ return err
+ }
+
+ // Create default Log Group
+ c.group = append(c.group, &LogGroup{
+ Topic: proto.String(""),
+ Source: proto.String(c.Source),
+ Logs: make([]*Log, 0, c.FlushWhen),
+ })
+
+ // Create other Log Group
+ c.groupMap = make(map[string]*LogGroup)
+ for _, topic := range c.Topics {
+
+ lg := &LogGroup{
+ Topic: proto.String(topic),
+ Source: proto.String(c.Source),
+ Logs: make([]*Log, 0, c.FlushWhen),
+ }
+
+ c.group = append(c.group, lg)
+ c.groupMap[topic] = lg
+ }
+
+ if len(c.group) == 1 {
+ c.withMap = false
+ } else {
+ c.withMap = true
+ }
+
+ c.lock = &sync.Mutex{}
+
+ return nil
+}
+
+// WriteMsg appends the message to the matching LogGroup and
+// flushes that group once it reaches the FlushWhen threshold.
+func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) {
+
+ if level > c.Level {
+ return nil
+ }
+
+ var topic string
+ var content string
+ var lg *LogGroup
+ if c.withMap {
+
+ // Topic,LogGroup
+ strs := strings.SplitN(msg, Delimiter, 2)
+ if len(strs) == 2 {
+ pos := strings.LastIndex(strs[0], " ")
+ topic = strs[0][pos+1 : len(strs[0])]
+ content = strs[0][0:pos] + strs[1]
+ lg = c.groupMap[topic]
+ }
+
+ // send to empty Topic
+ if lg == nil {
+ content = msg
+ lg = c.group[0]
+ }
+ } else {
+ content = msg
+ lg = c.group[0]
+ }
+
+ c1 := &LogContent{
+ Key: proto.String("msg"),
+ Value: proto.String(content),
+ }
+
+ l := &Log{
+ Time: proto.Uint32(uint32(when.Unix())),
+ Contents: []*LogContent{
+ c1,
+ },
+ }
+
+ c.lock.Lock()
+ lg.Logs = append(lg.Logs, l)
+ c.lock.Unlock()
+
+ if len(lg.Logs) >= c.FlushWhen {
+ c.flush(lg)
+ }
+
+ return nil
+}
+
+// Flush sends every buffered LogGroup to the log store.
+func (c *aliLSWriter) Flush() {
+
+ // flush all group
+ for _, lg := range c.group {
+ c.flush(lg)
+ }
+}
+
+// Destroy is a no-op; aliLSWriter holds no resources that need releasing.
+func (c *aliLSWriter) Destroy() {
+}
+
+func (c *aliLSWriter) flush(lg *LogGroup) {
+
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ err := c.store.PutLogs(lg)
+ if err != nil {
+ return
+ }
+
+ lg.Logs = make([]*Log, 0, c.FlushWhen)
+}
+
+func init() {
+ logs.Register(logs.AdapterAliLS, NewAliLS)
+}
diff --git a/pkg/logs/alils/config.go b/pkg/logs/alils/config.go
new file mode 100755
index 00000000..e8c24448
--- /dev/null
+++ b/pkg/logs/alils/config.go
@@ -0,0 +1,13 @@
+package alils
+
+const (
+ version = "0.5.0" // SDK version
+ signatureMethod = "hmac-sha1" // Signature method
+
+ // OffsetNewest stands for the log head offset, i.e. the offset that will be
+ // assigned to the next message that will be produced to the shard.
+ OffsetNewest = "end"
+ // OffsetOldest stands for the oldest offset available on the logstore for a
+ // shard.
+ OffsetOldest = "begin"
+)
diff --git a/pkg/logs/alils/log.pb.go b/pkg/logs/alils/log.pb.go
new file mode 100755
index 00000000..601b0d78
--- /dev/null
+++ b/pkg/logs/alils/log.pb.go
@@ -0,0 +1,1038 @@
+package alils
+
+import (
+ "fmt"
+ "io"
+ "math"
+
+ "github.com/gogo/protobuf/proto"
+ github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+var (
+ // ErrInvalidLengthLog invalid proto
+ ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling")
+ // ErrIntOverflowLog overflow
+ ErrIntOverflowLog = fmt.Errorf("proto: integer overflow")
+)
+
+// Log define the proto Log
+type Log struct {
+ Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"`
+ Contents []*LogContent `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"`
+ XXXUnrecognized []byte `json:"-"`
+}
+
+// Reset the Log
+func (m *Log) Reset() { *m = Log{} }
+
+// String return the Compact Log
+func (m *Log) String() string { return proto.CompactTextString(m) }
+
+// ProtoMessage not implemented
+func (*Log) ProtoMessage() {}
+
+// GetTime return the Log's Time
+func (m *Log) GetTime() uint32 {
+ if m != nil && m.Time != nil {
+ return *m.Time
+ }
+ return 0
+}
+
+// GetContents return the Log's Contents
+func (m *Log) GetContents() []*LogContent {
+ if m != nil {
+ return m.Contents
+ }
+ return nil
+}
+
+// LogContent define the Log content struct
+type LogContent struct {
+ Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"`
+ Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"`
+ XXXUnrecognized []byte `json:"-"`
+}
+
+// Reset LogContent
+func (m *LogContent) Reset() { *m = LogContent{} }
+
+// String return the compact text
+func (m *LogContent) String() string { return proto.CompactTextString(m) }
+
+// ProtoMessage not implemented
+func (*LogContent) ProtoMessage() {}
+
+// GetKey return the Key
+func (m *LogContent) GetKey() string {
+ if m != nil && m.Key != nil {
+ return *m.Key
+ }
+ return ""
+}
+
+// GetValue return the Value
+func (m *LogContent) GetValue() string {
+ if m != nil && m.Value != nil {
+ return *m.Value
+ }
+ return ""
+}
+
+// LogGroup define the logs struct
+type LogGroup struct {
+ Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"`
+ Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"`
+ Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"`
+ Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"`
+ XXXUnrecognized []byte `json:"-"`
+}
+
+// Reset LogGroup
+func (m *LogGroup) Reset() { *m = LogGroup{} }
+
+// String return the compact text
+func (m *LogGroup) String() string { return proto.CompactTextString(m) }
+
+// ProtoMessage not implemented
+func (*LogGroup) ProtoMessage() {}
+
+// GetLogs return the loggroup logs
+func (m *LogGroup) GetLogs() []*Log {
+ if m != nil {
+ return m.Logs
+ }
+ return nil
+}
+
+// GetReserved return Reserved
+func (m *LogGroup) GetReserved() string {
+ if m != nil && m.Reserved != nil {
+ return *m.Reserved
+ }
+ return ""
+}
+
+// GetTopic return Topic
+func (m *LogGroup) GetTopic() string {
+ if m != nil && m.Topic != nil {
+ return *m.Topic
+ }
+ return ""
+}
+
+// GetSource return Source
+func (m *LogGroup) GetSource() string {
+ if m != nil && m.Source != nil {
+ return *m.Source
+ }
+ return ""
+}
+
+// LogGroupList define the LogGroups
+type LogGroupList struct {
+ LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"`
+ XXXUnrecognized []byte `json:"-"`
+}
+
+// Reset LogGroupList
+func (m *LogGroupList) Reset() { *m = LogGroupList{} }
+
+// String return compact text
+func (m *LogGroupList) String() string { return proto.CompactTextString(m) }
+
+// ProtoMessage not implemented
+func (*LogGroupList) ProtoMessage() {}
+
+// GetLogGroups return the LogGroups
+func (m *LogGroupList) GetLogGroups() []*LogGroup {
+ if m != nil {
+ return m.LogGroups
+ }
+ return nil
+}
+
+// Marshal the logs to byte slice
+func (m *Log) Marshal() (data []byte, err error) {
+ size := m.Size()
+ data = make([]byte, size)
+ n, err := m.MarshalTo(data)
+ if err != nil {
+ return nil, err
+ }
+ return data[:n], nil
+}
+
+// MarshalTo data
+func (m *Log) MarshalTo(data []byte) (int, error) {
+ var i int
+ _ = i
+ var l int
+ _ = l
+ if m.Time == nil {
+ return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time")
+ }
+ data[i] = 0x8
+ i++
+ i = encodeVarintLog(data, i, uint64(*m.Time))
+ if len(m.Contents) > 0 {
+ for _, msg := range m.Contents {
+ data[i] = 0x12
+ i++
+ i = encodeVarintLog(data, i, uint64(msg.Size()))
+ n, err := msg.MarshalTo(data[i:])
+ if err != nil {
+ return 0, err
+ }
+ i += n
+ }
+ }
+ if m.XXXUnrecognized != nil {
+ i += copy(data[i:], m.XXXUnrecognized)
+ }
+ return i, nil
+}
+
+// Marshal LogContent
+func (m *LogContent) Marshal() (data []byte, err error) {
+ size := m.Size()
+ data = make([]byte, size)
+ n, err := m.MarshalTo(data)
+ if err != nil {
+ return nil, err
+ }
+ return data[:n], nil
+}
+
+// MarshalTo logcontent to data
+func (m *LogContent) MarshalTo(data []byte) (int, error) {
+ var i int
+ _ = i
+ var l int
+ _ = l
+ if m.Key == nil {
+ return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key")
+ }
+ data[i] = 0xa
+ i++
+ i = encodeVarintLog(data, i, uint64(len(*m.Key)))
+ i += copy(data[i:], *m.Key)
+
+ if m.Value == nil {
+ return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value")
+ }
+ data[i] = 0x12
+ i++
+ i = encodeVarintLog(data, i, uint64(len(*m.Value)))
+ i += copy(data[i:], *m.Value)
+ if m.XXXUnrecognized != nil {
+ i += copy(data[i:], m.XXXUnrecognized)
+ }
+ return i, nil
+}
+
+// Marshal LogGroup
+func (m *LogGroup) Marshal() (data []byte, err error) {
+ size := m.Size()
+ data = make([]byte, size)
+ n, err := m.MarshalTo(data)
+ if err != nil {
+ return nil, err
+ }
+ return data[:n], nil
+}
+
+// MarshalTo LogGroup to data
+func (m *LogGroup) MarshalTo(data []byte) (int, error) {
+ var i int
+ _ = i
+ var l int
+ _ = l
+ if len(m.Logs) > 0 {
+ for _, msg := range m.Logs {
+ data[i] = 0xa
+ i++
+ i = encodeVarintLog(data, i, uint64(msg.Size()))
+ n, err := msg.MarshalTo(data[i:])
+ if err != nil {
+ return 0, err
+ }
+ i += n
+ }
+ }
+ if m.Reserved != nil {
+ data[i] = 0x12
+ i++
+ i = encodeVarintLog(data, i, uint64(len(*m.Reserved)))
+ i += copy(data[i:], *m.Reserved)
+ }
+ if m.Topic != nil {
+ data[i] = 0x1a
+ i++
+ i = encodeVarintLog(data, i, uint64(len(*m.Topic)))
+ i += copy(data[i:], *m.Topic)
+ }
+ if m.Source != nil {
+ data[i] = 0x22
+ i++
+ i = encodeVarintLog(data, i, uint64(len(*m.Source)))
+ i += copy(data[i:], *m.Source)
+ }
+ if m.XXXUnrecognized != nil {
+ i += copy(data[i:], m.XXXUnrecognized)
+ }
+ return i, nil
+}
+
+// Marshal LogGroupList
+func (m *LogGroupList) Marshal() (data []byte, err error) {
+ size := m.Size()
+ data = make([]byte, size)
+ n, err := m.MarshalTo(data)
+ if err != nil {
+ return nil, err
+ }
+ return data[:n], nil
+}
+
+// MarshalTo LogGroupList to data
+func (m *LogGroupList) MarshalTo(data []byte) (int, error) {
+ var i int
+ _ = i
+ var l int
+ _ = l
+ if len(m.LogGroups) > 0 {
+ for _, msg := range m.LogGroups {
+ data[i] = 0xa
+ i++
+ i = encodeVarintLog(data, i, uint64(msg.Size()))
+ n, err := msg.MarshalTo(data[i:])
+ if err != nil {
+ return 0, err
+ }
+ i += n
+ }
+ }
+ if m.XXXUnrecognized != nil {
+ i += copy(data[i:], m.XXXUnrecognized)
+ }
+ return i, nil
+}
+
+func encodeFixed64Log(data []byte, offset int, v uint64) int {
+ data[offset] = uint8(v)
+ data[offset+1] = uint8(v >> 8)
+ data[offset+2] = uint8(v >> 16)
+ data[offset+3] = uint8(v >> 24)
+ data[offset+4] = uint8(v >> 32)
+ data[offset+5] = uint8(v >> 40)
+ data[offset+6] = uint8(v >> 48)
+ data[offset+7] = uint8(v >> 56)
+ return offset + 8
+}
+func encodeFixed32Log(data []byte, offset int, v uint32) int {
+ data[offset] = uint8(v)
+ data[offset+1] = uint8(v >> 8)
+ data[offset+2] = uint8(v >> 16)
+ data[offset+3] = uint8(v >> 24)
+ return offset + 4
+}
+func encodeVarintLog(data []byte, offset int, v uint64) int {
+ for v >= 1<<7 {
+ data[offset] = uint8(v&0x7f | 0x80)
+ v >>= 7
+ offset++
+ }
+ data[offset] = uint8(v)
+ return offset + 1
+}
+
+// Size return the log's size
+func (m *Log) Size() (n int) {
+ var l int
+ _ = l
+ if m.Time != nil {
+ n += 1 + sovLog(uint64(*m.Time))
+ }
+ if len(m.Contents) > 0 {
+ for _, e := range m.Contents {
+ l = e.Size()
+ n += 1 + l + sovLog(uint64(l))
+ }
+ }
+ if m.XXXUnrecognized != nil {
+ n += len(m.XXXUnrecognized)
+ }
+ return n
+}
+
+// Size return LogContent size based on Key and Value
+func (m *LogContent) Size() (n int) {
+ var l int
+ _ = l
+ if m.Key != nil {
+ l = len(*m.Key)
+ n += 1 + l + sovLog(uint64(l))
+ }
+ if m.Value != nil {
+ l = len(*m.Value)
+ n += 1 + l + sovLog(uint64(l))
+ }
+ if m.XXXUnrecognized != nil {
+ n += len(m.XXXUnrecognized)
+ }
+ return n
+}
+
+// Size return LogGroup size based on Logs
+func (m *LogGroup) Size() (n int) {
+ var l int
+ _ = l
+ if len(m.Logs) > 0 {
+ for _, e := range m.Logs {
+ l = e.Size()
+ n += 1 + l + sovLog(uint64(l))
+ }
+ }
+ if m.Reserved != nil {
+ l = len(*m.Reserved)
+ n += 1 + l + sovLog(uint64(l))
+ }
+ if m.Topic != nil {
+ l = len(*m.Topic)
+ n += 1 + l + sovLog(uint64(l))
+ }
+ if m.Source != nil {
+ l = len(*m.Source)
+ n += 1 + l + sovLog(uint64(l))
+ }
+ if m.XXXUnrecognized != nil {
+ n += len(m.XXXUnrecognized)
+ }
+ return n
+}
+
+// Size return LogGroupList size
+func (m *LogGroupList) Size() (n int) {
+ var l int
+ _ = l
+ if len(m.LogGroups) > 0 {
+ for _, e := range m.LogGroups {
+ l = e.Size()
+ n += 1 + l + sovLog(uint64(l))
+ }
+ }
+ if m.XXXUnrecognized != nil {
+ n += len(m.XXXUnrecognized)
+ }
+ return n
+}
+
+func sovLog(x uint64) (n int) {
+ for {
+ n++
+ x >>= 7
+ if x == 0 {
+ break
+ }
+ }
+ return n
+}
+func sozLog(x uint64) (n int) {
+ return sovLog((x << 1) ^ (x >> 63))
+}
+
+// Unmarshal data to log
+func (m *Log) Unmarshal(data []byte) error {
+ var hasFields [1]uint64
+ l := len(data)
+ iNdEx := 0
+ for iNdEx < l {
+ preIndex := iNdEx
+ var wire uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ wire |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ fieldNum := int32(wire >> 3)
+ wireType := int(wire & 0x7)
+ if wireType == 4 {
+ return fmt.Errorf("proto: Log: wiretype end group for non-group")
+ }
+ if fieldNum <= 0 {
+ return fmt.Errorf("proto: Log: illegal tag %d (wire type %d)", fieldNum, wire)
+ }
+ switch fieldNum {
+ case 1:
+ if wireType != 0 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType)
+ }
+ var v uint32
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ v |= (uint32(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ m.Time = &v
+ hasFields[0] |= uint64(0x00000001)
+ case 2:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Contents", wireType)
+ }
+ var msglen int
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ msglen |= (int(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ if msglen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + msglen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.Contents = append(m.Contents, &LogContent{})
+ if err := m.Contents[len(m.Contents)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
+ return err
+ }
+ iNdEx = postIndex
+ default:
+ iNdEx = preIndex
+ skippy, err := skipLog(data[iNdEx:])
+ if err != nil {
+ return err
+ }
+ if skippy < 0 {
+ return ErrInvalidLengthLog
+ }
+ if (iNdEx + skippy) > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...)
+ iNdEx += skippy
+ }
+ }
+ if hasFields[0]&uint64(0x00000001) == 0 {
+ return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time")
+ }
+
+ if iNdEx > l {
+ return io.ErrUnexpectedEOF
+ }
+ return nil
+}
+
+// Unmarshal data to LogContent
+func (m *LogContent) Unmarshal(data []byte) error {
+ var hasFields [1]uint64
+ l := len(data)
+ iNdEx := 0
+ for iNdEx < l {
+ preIndex := iNdEx
+ var wire uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ wire |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ fieldNum := int32(wire >> 3)
+ wireType := int(wire & 0x7)
+ if wireType == 4 {
+ return fmt.Errorf("proto: Content: wiretype end group for non-group")
+ }
+ if fieldNum <= 0 {
+ return fmt.Errorf("proto: Content: illegal tag %d (wire type %d)", fieldNum, wire)
+ }
+ switch fieldNum {
+ case 1:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ s := string(data[iNdEx:postIndex])
+ m.Key = &s
+ iNdEx = postIndex
+ hasFields[0] |= uint64(0x00000001)
+ case 2:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Value", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ s := string(data[iNdEx:postIndex])
+ m.Value = &s
+ iNdEx = postIndex
+ hasFields[0] |= uint64(0x00000002)
+ default:
+ iNdEx = preIndex
+ skippy, err := skipLog(data[iNdEx:])
+ if err != nil {
+ return err
+ }
+ if skippy < 0 {
+ return ErrInvalidLengthLog
+ }
+ if (iNdEx + skippy) > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...)
+ iNdEx += skippy
+ }
+ }
+ if hasFields[0]&uint64(0x00000001) == 0 {
+ return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key")
+ }
+ if hasFields[0]&uint64(0x00000002) == 0 {
+ return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value")
+ }
+
+ if iNdEx > l {
+ return io.ErrUnexpectedEOF
+ }
+ return nil
+}
+
+// Unmarshal data to LogGroup
+func (m *LogGroup) Unmarshal(data []byte) error {
+ l := len(data)
+ iNdEx := 0
+ for iNdEx < l {
+ preIndex := iNdEx
+ var wire uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ wire |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ fieldNum := int32(wire >> 3)
+ wireType := int(wire & 0x7)
+ if wireType == 4 {
+ return fmt.Errorf("proto: LogGroup: wiretype end group for non-group")
+ }
+ if fieldNum <= 0 {
+ return fmt.Errorf("proto: LogGroup: illegal tag %d (wire type %d)", fieldNum, wire)
+ }
+ switch fieldNum {
+ case 1:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Logs", wireType)
+ }
+ var msglen int
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ msglen |= (int(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ if msglen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + msglen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.Logs = append(m.Logs, &Log{})
+ if err := m.Logs[len(m.Logs)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
+ return err
+ }
+ iNdEx = postIndex
+ case 2:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Reserved", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ s := string(data[iNdEx:postIndex])
+ m.Reserved = &s
+ iNdEx = postIndex
+ case 3:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Topic", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ s := string(data[iNdEx:postIndex])
+ m.Topic = &s
+ iNdEx = postIndex
+ case 4:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Source", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ s := string(data[iNdEx:postIndex])
+ m.Source = &s
+ iNdEx = postIndex
+ default:
+ iNdEx = preIndex
+ skippy, err := skipLog(data[iNdEx:])
+ if err != nil {
+ return err
+ }
+ if skippy < 0 {
+ return ErrInvalidLengthLog
+ }
+ if (iNdEx + skippy) > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...)
+ iNdEx += skippy
+ }
+ }
+
+ if iNdEx > l {
+ return io.ErrUnexpectedEOF
+ }
+ return nil
+}
+
+// Unmarshal data to LogGroupList
+func (m *LogGroupList) Unmarshal(data []byte) error {
+ l := len(data)
+ iNdEx := 0
+ for iNdEx < l {
+ preIndex := iNdEx
+ var wire uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ wire |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ fieldNum := int32(wire >> 3)
+ wireType := int(wire & 0x7)
+ if wireType == 4 {
+ return fmt.Errorf("proto: LogGroupList: wiretype end group for non-group")
+ }
+ if fieldNum <= 0 {
+ return fmt.Errorf("proto: LogGroupList: illegal tag %d (wire type %d)", fieldNum, wire)
+ }
+ switch fieldNum {
+ case 1:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field LogGroups", wireType)
+ }
+ var msglen int
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ msglen |= (int(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ if msglen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + msglen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.LogGroups = append(m.LogGroups, &LogGroup{})
+ if err := m.LogGroups[len(m.LogGroups)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
+ return err
+ }
+ iNdEx = postIndex
+ default:
+ iNdEx = preIndex
+ skippy, err := skipLog(data[iNdEx:])
+ if err != nil {
+ return err
+ }
+ if skippy < 0 {
+ return ErrInvalidLengthLog
+ }
+ if (iNdEx + skippy) > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...)
+ iNdEx += skippy
+ }
+ }
+
+ if iNdEx > l {
+ return io.ErrUnexpectedEOF
+ }
+ return nil
+}
+
+func skipLog(data []byte) (n int, err error) {
+	l := len(data)
+	iNdEx := 0
+	for iNdEx < l {
+		var wire uint64
+		for shift := uint(0); ; shift += 7 {
+			if shift >= 64 {
+				return 0, ErrIntOverflowLog
+			}
+			if iNdEx >= l {
+				return 0, io.ErrUnexpectedEOF
+			}
+			b := data[iNdEx]
+			iNdEx++
+			wire |= (uint64(b) & 0x7F) << shift
+			if b < 0x80 {
+				break
+			}
+		}
+		wireType := int(wire & 0x7)
+		switch wireType {
+		case 0:
+			for shift := uint(0); ; shift += 7 {
+				if shift >= 64 {
+					return 0, ErrIntOverflowLog
+				}
+				if iNdEx >= l {
+					return 0, io.ErrUnexpectedEOF
+				}
+				iNdEx++
+				if data[iNdEx-1] < 0x80 {
+					break
+				}
+			}
+			return iNdEx, nil
+		case 1:
+			iNdEx += 8
+			return iNdEx, nil
+		case 2:
+			var length int
+			for shift := uint(0); ; shift += 7 {
+				if shift >= 64 {
+					return 0, ErrIntOverflowLog
+				}
+				if iNdEx >= l {
+					return 0, io.ErrUnexpectedEOF
+				}
+				b := data[iNdEx]
+				iNdEx++
+				length |= (int(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+			if length < 0 { // validate before advancing iNdEx (matches upstream gogo fix)
+				return 0, ErrInvalidLengthLog
+			}
+			iNdEx += length
+			return iNdEx, nil
+		case 3:
+			for {
+				var innerWire uint64
+				var start = iNdEx
+				for shift := uint(0); ; shift += 7 {
+					if shift >= 64 {
+						return 0, ErrIntOverflowLog
+					}
+					if iNdEx >= l {
+						return 0, io.ErrUnexpectedEOF
+					}
+					b := data[iNdEx]
+					iNdEx++
+					innerWire |= (uint64(b) & 0x7F) << shift
+					if b < 0x80 {
+						break
+					}
+				}
+				innerWireType := int(innerWire & 0x7)
+				if innerWireType == 4 {
+					break
+				}
+				next, err := skipLog(data[start:])
+				if err != nil {
+					return 0, err
+				}
+				iNdEx = start + next
+			}
+			return iNdEx, nil
+		case 4:
+			return iNdEx, nil
+		case 5:
+			iNdEx += 4
+			return iNdEx, nil
+		default:
+			return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
+		}
+	}
+	panic("unreachable")
+}
diff --git a/pkg/logs/alils/log_config.go b/pkg/logs/alils/log_config.go
new file mode 100755
index 00000000..e8564efb
--- /dev/null
+++ b/pkg/logs/alils/log_config.go
@@ -0,0 +1,42 @@
+package alils
+
+// InputDetail define log detail
+type InputDetail struct {
+ LogType string `json:"logType"`
+ LogPath string `json:"logPath"`
+ FilePattern string `json:"filePattern"`
+ LocalStorage bool `json:"localStorage"`
+ TimeFormat string `json:"timeFormat"`
+ LogBeginRegex string `json:"logBeginRegex"`
+ Regex string `json:"regex"`
+ Keys []string `json:"key"`
+ FilterKeys []string `json:"filterKey"`
+ FilterRegex []string `json:"filterRegex"`
+ TopicFormat string `json:"topicFormat"`
+}
+
+// OutputDetail define the output detail
+type OutputDetail struct {
+ Endpoint string `json:"endpoint"`
+ LogStoreName string `json:"logstoreName"`
+}
+
+// LogConfig define Log Config
+type LogConfig struct {
+ Name string `json:"configName"`
+ InputType string `json:"inputType"`
+ InputDetail InputDetail `json:"inputDetail"`
+ OutputType string `json:"outputType"`
+ OutputDetail OutputDetail `json:"outputDetail"`
+
+ CreateTime uint32
+ LastModifyTime uint32
+
+ project *LogProject
+}
+
+// GetAppliedMachineGroup returns the machine groups that config confName is applied to.
+func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) {
+	groupNames, err = c.project.GetAppliedMachineGroups(confName) // was c.Name, which silently ignored confName
+	return
+}
diff --git a/pkg/logs/alils/log_project.go b/pkg/logs/alils/log_project.go
new file mode 100755
index 00000000..59db8cbf
--- /dev/null
+++ b/pkg/logs/alils/log_project.go
@@ -0,0 +1,819 @@
+/*
+Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS).
+
+For more description about SLS, please read this article:
+http://gitlab.alibaba-inc.com/sls/doc.
+*/
+package alils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httputil"
+)
+
+// Error message in SLS HTTP response.
+type errorMessage struct {
+ Code string `json:"errorCode"`
+ Message string `json:"errorMessage"`
+}
+
+// LogProject Define the Ali Project detail
+type LogProject struct {
+ Name string // Project name
+ Endpoint string // IP or hostname of SLS endpoint
+ AccessKeyID string
+ AccessKeySecret string
+}
+
+// NewLogProject creates a new SLS project.
+func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) {
+ p = &LogProject{
+ Name: name,
+ Endpoint: endpoint,
+ AccessKeyID: AccessKeyID,
+ AccessKeySecret: accessKeySecret,
+ }
+ return p, nil
+}
+
+// ListLogStore returns all logstore names of project p.
+func (p *LogProject) ListLogStore() (storeNames []string, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	uri := "/logstores" // no format verbs — Sprintf was redundant (staticcheck S1039)
+	r, err := request(p, "GET", uri, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			err = fmt.Errorf("failed to list logstore")
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	type Body struct {
+		Count     int
+		LogStores []string
+	}
+	body := &Body{}
+
+	err = json.Unmarshal(buf, body)
+	if err != nil {
+		return
+	}
+
+	storeNames = body.LogStores
+
+	return
+}
+
+// GetLogStore returns logstore according by logstore name.
+func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ r, err := request(p, "GET", "/logstores/"+name, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to get logstore")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ s = &LogStore{}
+ err = json.Unmarshal(buf, s)
+ if err != nil {
+ return
+ }
+ s.project = p
+ return
+}
+
+// CreateLogStore creates a new logstore in SLS,
+// where name is logstore name,
+// and ttl is time-to-live(in day) of logs,
+// and shardCnt is the number of shards.
+func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) {
+
+ type Body struct {
+ Name string `json:"logstoreName"`
+ TTL int `json:"ttl"`
+ ShardCount int `json:"shardCount"`
+ }
+
+ store := &Body{
+ Name: name,
+ TTL: ttl,
+ ShardCount: shardCnt,
+ }
+
+ body, err := json.Marshal(store)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/json",
+ "Accept-Encoding": "deflate", // TODO: support lz4
+ }
+
+ r, err := request(p, "POST", "/logstores", h, body)
+ if err != nil {
+ return
+ }
+
+ body, err = ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to create logstore")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ return
+}
+
+// DeleteLogStore deletes a logstore according by logstore name.
+func (p *LogProject) DeleteLogStore(name string) (err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ r, err := request(p, "DELETE", "/logstores/"+name, h, nil)
+ if err != nil {
+ return
+ }
+
+ body, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to delete logstore")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
+
+// UpdateLogStore updates a logstore according by logstore name,
+// obviously we can't modify the logstore name itself.
+func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) {
+
+ type Body struct {
+ Name string `json:"logstoreName"`
+ TTL int `json:"ttl"`
+ ShardCount int `json:"shardCount"`
+ }
+
+ store := &Body{
+ Name: name,
+ TTL: ttl,
+ ShardCount: shardCnt,
+ }
+
+ body, err := json.Marshal(store)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/json",
+ "Accept-Encoding": "deflate", // TODO: support lz4
+ }
+
+ r, err := request(p, "PUT", "/logstores", h, body)
+ if err != nil {
+ return
+ }
+
+ body, err = ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to update logstore")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ return
+}
+
+// ListMachineGroup returns machine group name list and the total number of machine groups.
+// The offset starts from 0 and the size is the max number of machine groups could be returned.
+func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ if size <= 0 {
+ size = 500
+ }
+
+ uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size)
+ r, err := request(p, "GET", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to list machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ type Body struct {
+ MachineGroups []string
+ Count int
+ Total int
+ }
+ body := &Body{}
+
+ err = json.Unmarshal(buf, body)
+ if err != nil {
+ return
+ }
+
+ m = body.MachineGroups
+ total = body.Total
+
+ return
+}
+
+// GetMachineGroup returns the machine group with the given name.
+func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ r, err := request(p, "GET", "/machinegroups/"+name, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to get machine group:%v", name)
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ m = &MachineGroup{}
+ err = json.Unmarshal(buf, m)
+ if err != nil {
+ return
+ }
+ m.project = p
+ return
+}
+
+// CreateMachineGroup creates a new machine group in SLS.
+func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) {
+
+ body, err := json.Marshal(m)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/json",
+ "Accept-Encoding": "deflate", // TODO: support lz4
+ }
+
+ r, err := request(p, "POST", "/machinegroups", h, body)
+ if err != nil {
+ return
+ }
+
+ body, err = ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to create machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ return
+}
+
+// UpdateMachineGroup updates a machine group.
+func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) {
+
+ body, err := json.Marshal(m)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/json",
+ "Accept-Encoding": "deflate", // TODO: support lz4
+ }
+
+ r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body)
+ if err != nil {
+ return
+ }
+
+ body, err = ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to update machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ return
+}
+
+// DeleteMachineGroup deletes machine group according machine group name.
+func (p *LogProject) DeleteMachineGroup(name string) (err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil)
+ if err != nil {
+ return
+ }
+
+ body, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to delete machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
+
+// ListConfig returns config names list and the total number of configs.
+// The offset starts from 0 and the size is the max number of configs could be returned.
+func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	if size <= 0 {
+		size = 100
+	}
+
+	uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size)
+	r, err := request(p, "GET", uri, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			err = fmt.Errorf("failed to list configs") // was a copy-pasted "delete machine group" message
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	type Body struct {
+		Total   int
+		Configs []string
+	}
+	body := &Body{}
+
+	err = json.Unmarshal(buf, body)
+	if err != nil {
+		return
+	}
+
+	cfgNames = body.Configs
+	total = body.Total
+	return
+}
+
+// GetConfig returns config according by config name.
+func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	r, err := request(p, "GET", "/configs/"+name, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			err = fmt.Errorf("failed to get config") // was "failed to delete config" — wrong operation
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	c = &LogConfig{}
+	err = json.Unmarshal(buf, c)
+	if err != nil {
+		return
+	}
+	c.project = p
+	return
+}
+
+// UpdateConfig updates a config.
+func (p *LogProject) UpdateConfig(c *LogConfig) (err error) {
+
+ body, err := json.Marshal(c)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/json",
+ "Accept-Encoding": "deflate", // TODO: support lz4
+ }
+
+ r, err := request(p, "PUT", "/configs/"+c.Name, h, body)
+ if err != nil {
+ return
+ }
+
+ body, err = ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to update config")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ return
+}
+
+// CreateConfig creates a new config in SLS.
+func (p *LogProject) CreateConfig(c *LogConfig) (err error) {
+
+ body, err := json.Marshal(c)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/json",
+ "Accept-Encoding": "deflate", // TODO: support lz4
+ }
+
+ r, err := request(p, "POST", "/configs", h, body)
+ if err != nil {
+ return
+ }
+
+ body, err = ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to update config")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ return
+}
+
+// DeleteConfig deletes a config according by config name.
+func (p *LogProject) DeleteConfig(name string) (err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ r, err := request(p, "DELETE", "/configs/"+name, h, nil)
+ if err != nil {
+ return
+ }
+
+ body, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to delete config")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
+
+// GetAppliedMachineGroups returns applied machine group names list according config name.
+func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ uri := fmt.Sprintf("/configs/%v/machinegroups", confName)
+ r, err := request(p, "GET", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to get applied machine groups")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ type Body struct {
+ Count int
+ Machinegroups []string
+ }
+
+ body := &Body{}
+ err = json.Unmarshal(buf, body)
+ if err != nil {
+ return
+ }
+
+ groupNames = body.Machinegroups
+ return
+}
+
+// GetAppliedConfigs returns applied config names list according machine group name groupName.
+func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ uri := fmt.Sprintf("/machinegroups/%v/configs", groupName)
+ r, err := request(p, "GET", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to applied configs")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ type Cfg struct {
+ Count int `json:"count"`
+ Configs []string `json:"configs"`
+ }
+
+ body := &Cfg{}
+ err = json.Unmarshal(buf, body)
+ if err != nil {
+ return
+ }
+
+ confNames = body.Configs
+ return
+}
+
+// ApplyConfigToMachineGroup applies config to machine group.
+func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName)
+ r, err := request(p, "PUT", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to apply config to machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
+
+// RemoveConfigFromMachineGroup removes config from machine group.
+func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName)
+ r, err := request(p, "DELETE", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to remove config from machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
diff --git a/pkg/logs/alils/log_store.go b/pkg/logs/alils/log_store.go
new file mode 100755
index 00000000..fa502736
--- /dev/null
+++ b/pkg/logs/alils/log_store.go
@@ -0,0 +1,271 @@
+package alils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httputil"
+ "strconv"
+
+ lz4 "github.com/cloudflare/golz4"
+ "github.com/gogo/protobuf/proto"
+)
+
+// LogStore Store the logs
+type LogStore struct {
+ Name string `json:"logstoreName"`
+ TTL int
+ ShardCount int
+
+ CreateTime uint32
+ LastModifyTime uint32
+
+ project *LogProject
+}
+
+// Shard define the Log Shard
+type Shard struct {
+ ShardID int `json:"shardID"`
+}
+
+// ListShards returns shard id list of this logstore.
+func (s *LogStore) ListShards() (shardIDs []int, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	uri := fmt.Sprintf("/logstores/%v/shards", s.Name)
+	r, err := request(s.project, "GET", uri, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			err = fmt.Errorf("failed to list shards") // was "failed to list logstore" — wrong operation
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump) // %s: Println would print the dump as raw byte values
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	var shards []*Shard
+	err = json.Unmarshal(buf, &shards)
+	if err != nil {
+		return
+	}
+
+	for _, v := range shards {
+		shardIDs = append(shardIDs, v.ShardID)
+	}
+	return
+}
+
+// PutLogs put logs into logstore.
+// The callers should transform user logs into LogGroup.
+func (s *LogStore) PutLogs(lg *LogGroup) (err error) {
+ body, err := proto.Marshal(lg)
+ if err != nil {
+ return
+ }
+
+	// Compress body with lz4
+ out := make([]byte, lz4.CompressBound(body))
+ n, err := lz4.Compress(body, out)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-compresstype": "lz4",
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/x-protobuf",
+ }
+
+ uri := fmt.Sprintf("/logstores/%v", s.Name)
+ r, err := request(s.project, "POST", uri, h, out[:n])
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to put logs")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Println(dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
+
+// GetCursor gets log cursor of one shard specified by shardID.
+// The from parameter can take three forms: a) unix timestamp in seconds, b) "begin", c) "end".
+// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore
+func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v",
+ s.Name, shardID, from)
+
+ r, err := request(s.project, "GET", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to get cursor")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Println(dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ type Body struct {
+ Cursor string
+ }
+ body := &Body{}
+
+ err = json.Unmarshal(buf, body)
+ if err != nil {
+ return
+ }
+ cursor = body.Cursor
+ return
+}
+
+// GetLogsBytes gets logs binary data from shard specified by shardID according cursor.
+// The logGroupMaxCount is the max number of logGroup could be returned.
+// The nextCursor is the next cursor that can be used to read logs at next time.
+func (s *LogStore) GetLogsBytes(shardID int, cursor string,
+	logGroupMaxCount int) (out []byte, nextCursor string, err error) {
+
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+		"Accept":            "application/x-protobuf",
+		"Accept-Encoding":   "lz4",
+	}
+
+	uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v",
+		s.Name, shardID, cursor, logGroupMaxCount)
+
+	r, err := request(s.project, "GET", uri, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			err = fmt.Errorf("failed to get logs") // was "failed to get cursor" — copy-pasted from GetCursor
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump) // %s: Println would print the dump as raw byte values
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	v, ok := r.Header["X-Sls-Compresstype"]
+	if !ok || len(v) == 0 {
+		err = fmt.Errorf("can't find 'x-sls-compresstype' header")
+		return
+	}
+	if v[0] != "lz4" {
+		err = fmt.Errorf("unexpected compress type:%v", v[0])
+		return
+	}
+
+	v, ok = r.Header["X-Sls-Cursor"]
+	if !ok || len(v) == 0 {
+		err = fmt.Errorf("can't find 'x-sls-cursor' header")
+		return
+	}
+	nextCursor = v[0]
+
+	v, ok = r.Header["X-Sls-Bodyrawsize"]
+	if !ok || len(v) == 0 {
+		err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header")
+		return
+	}
+	bodyRawSize, err := strconv.Atoi(v[0])
+	if err != nil {
+		return
+	}
+
+	out = make([]byte, bodyRawSize)
+	err = lz4.Uncompress(buf, out)
+	if err != nil {
+		return
+	}
+
+	return
+}
+
+// LogsBytesDecode decodes logs binary data returned by the GetLogsBytes API.
+func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) {
+
+ gl = &LogGroupList{}
+ err = proto.Unmarshal(data, gl)
+ if err != nil {
+ return
+ }
+
+ return
+}
+
+// GetLogs gets logs from shard specified by shardID according cursor.
+// The logGroupMaxCount is the max number of logGroup could be returned.
+// The nextCursor is the next cursor that can be used to read logs at next time.
+func (s *LogStore) GetLogs(shardID int, cursor string,
+ logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) {
+
+ out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount)
+ if err != nil {
+ return
+ }
+
+ gl, err = LogsBytesDecode(out)
+ if err != nil {
+ return
+ }
+
+ return
+}
diff --git a/pkg/logs/alils/machine_group.go b/pkg/logs/alils/machine_group.go
new file mode 100755
index 00000000..b6c69a14
--- /dev/null
+++ b/pkg/logs/alils/machine_group.go
@@ -0,0 +1,91 @@
+package alils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httputil"
+)
+
+// MachineGroupAttribute defines the optional attributes of a machine group.
+type MachineGroupAttribute struct {
+	// ExternalName is the external identity name of the group.
+	ExternalName string `json:"externalName"`
+	// TopicName is the log topic assigned to the group.
+	TopicName string `json:"groupTopic"`
+}
+
+// MachineGroup defines a machine group and its membership.
+type MachineGroup struct {
+	Name string `json:"groupName"`
+	Type string `json:"groupType"`
+	// MachineIDType tells how entries in MachineIDList identify machines
+	// (presumably by IP or user-defined id — TODO confirm against the API).
+	MachineIDType string   `json:"machineIdentifyType"`
+	MachineIDList []string `json:"machineList"`
+
+	Attribute MachineGroupAttribute `json:"groupAttribute"`
+
+	// CreateTime / LastModifyTime are set by the server; they carry no json
+	// tags, so they are not populated from the API's lower-case field names.
+	CreateTime     uint32
+	LastModifyTime uint32
+
+	// project is the owning LogProject, used to issue API requests.
+	project *LogProject
+}
+
+// Machine defines a single machine in a machine group.
+type Machine struct {
+	IP            string
+	UniqueID      string `json:"machine-uniqueid"`
+	UserdefinedID string `json:"userdefined-id"`
+}
+
+// MachineList defines the paged machine list returned by ListMachines.
+type MachineList struct {
+	Total    int
+	Machines []*Machine
+}
+
+// ListMachines returns the machine list of this machine group, together with
+// the total number of machines reported by the server.
+func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name)
+	r, err := request(m.project, "GET", uri, h, nil)
+	if err != nil {
+		return
+	}
+	// Fix: the response body was never closed, leaking the connection.
+	defer r.Body.Close()
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			// Fix: the fallback message was copy-pasted from the
+			// remove-config API, and the dump was printed as a raw
+			// byte slice (a list of numbers) instead of text.
+			err = fmt.Errorf("failed to list machines in machine group")
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	body := &MachineList{}
+	err = json.Unmarshal(buf, body)
+	if err != nil {
+		return
+	}
+
+	ms = body.Machines
+	total = body.Total
+	return
+}
+
+// GetAppliedConfigs returns the applied config names of this machine group.
+func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) {
+	confNames, err = m.project.GetAppliedConfigs(m.Name)
+	return
+}
diff --git a/pkg/logs/alils/request.go b/pkg/logs/alils/request.go
new file mode 100755
index 00000000..50d9c43c
--- /dev/null
+++ b/pkg/logs/alils/request.go
@@ -0,0 +1,62 @@
+package alils
+
+import (
+ "bytes"
+ "crypto/md5"
+ "fmt"
+ "net/http"
+)
+
+// request sends an HTTP request to the SLS service, adding the common SLS
+// headers and the "SLS <AccessKeyID>:<signature>" Authorization header.
+// The caller owns the returned response and must close its Body.
+func request(project *LogProject, method, uri string, headers map[string]string,
+	body []byte) (resp *http.Response, err error) {
+
+	// The caller must provide the 'x-sls-bodyrawsize' header.
+	if _, ok := headers["x-sls-bodyrawsize"]; !ok {
+		// Fix: error strings lowercased for consistency with the rest of
+		// the package (cf. "can't find 'x-sls-cursor' header").
+		err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header")
+		return
+	}
+
+	// SLS public request headers.
+	headers["Host"] = project.Name + "." + project.Endpoint
+	headers["Date"] = nowRFC1123()
+	headers["x-sls-apiversion"] = version
+	headers["x-sls-signaturemethod"] = signatureMethod
+	if body != nil {
+		bodyMD5 := fmt.Sprintf("%X", md5.Sum(body))
+		headers["Content-MD5"] = bodyMD5
+
+		// A body requires an explicit Content-Type for signing.
+		if _, ok := headers["Content-Type"]; !ok {
+			err = fmt.Errorf("can't find 'Content-Type' header")
+			return
+		}
+	}
+
+	// Calc Authorization
+	// Authorization = "SLS <AccessKeyID>:<signature>"
+	digest, err := signature(project, method, uri, headers)
+	if err != nil {
+		return
+	}
+	auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest)
+	headers["Authorization"] = auth
+
+	// Initialize http request
+	reader := bytes.NewReader(body)
+	urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri)
+	req, err := http.NewRequest(method, urlStr, reader)
+	if err != nil {
+		return
+	}
+	for k, v := range headers {
+		req.Header.Add(k, v)
+	}
+
+	// Fix: the trailing "if err != nil { return }; return" was redundant.
+	resp, err = http.DefaultClient.Do(req)
+	return
+}
diff --git a/pkg/logs/alils/signature.go b/pkg/logs/alils/signature.go
new file mode 100755
index 00000000..2d611307
--- /dev/null
+++ b/pkg/logs/alils/signature.go
@@ -0,0 +1,111 @@
+package alils
+
+import (
+ "crypto/hmac"
+ "crypto/sha1"
+ "encoding/base64"
+ "fmt"
+ "net/url"
+ "sort"
+ "strings"
+ "time"
+)
+
+// gmtLoc pins time formatting to the GMT zone required by the Date header.
+var gmtLoc = time.FixedZone("GMT", 0)
+
+// nowRFC1123 returns the current time in RFC1123 format with GMT timezone,
+// eg. "Mon, 02 Jan 2006 15:04:05 GMT".
+func nowRFC1123() string {
+	return time.Now().In(gmtLoc).Format(time.RFC1123)
+}
+
+// signature calculates a request's signature digest: an HMAC-SHA1 of the
+// canonical sign string, base64-encoded. The construction below is
+// order-sensitive; do not reorder the concatenations.
+func signature(project *LogProject, method, uri string,
+	headers map[string]string) (digest string, err error) {
+	var contentMD5, contentType, date, canoHeaders, canoResource string
+	var slsHeaderKeys sort.StringSlice
+
+	// SignString = VERB + "\n"
+	//             + CONTENT-MD5 + "\n"
+	//             + CONTENT-TYPE + "\n"
+	//             + DATE + "\n"
+	//             + CanonicalizedSLSHeaders + "\n"
+	//             + CanonicalizedResource
+
+	// Content-MD5 and Content-Type are optional and default to empty.
+	if val, ok := headers["Content-MD5"]; ok {
+		contentMD5 = val
+	}
+
+	if val, ok := headers["Content-Type"]; ok {
+		contentType = val
+	}
+
+	// Date is mandatory for signing.
+	date, ok := headers["Date"]
+	if !ok {
+		err = fmt.Errorf("Can't find 'Date' header")
+		return
+	}
+
+	// Calc CanonicalizedSLSHeaders: every x-sls-* header, lowercased,
+	// trimmed, sorted by key, joined as "key:value" lines.
+	slsHeaders := make(map[string]string, len(headers))
+	for k, v := range headers {
+		l := strings.TrimSpace(strings.ToLower(k))
+		if strings.HasPrefix(l, "x-sls-") {
+			slsHeaders[l] = strings.TrimSpace(v)
+			slsHeaderKeys = append(slsHeaderKeys, l)
+		}
+	}
+
+	sort.Sort(slsHeaderKeys)
+	for i, k := range slsHeaderKeys {
+		canoHeaders += k + ":" + slsHeaders[k]
+		if i+1 < len(slsHeaderKeys) {
+			canoHeaders += "\n"
+		}
+	}
+
+	// Calc CanonicalizedResource: escaped path plus query params sorted
+	// by key.
+	u, err := url.Parse(uri)
+	if err != nil {
+		return
+	}
+
+	canoResource += url.QueryEscape(u.Path)
+	if u.RawQuery != "" {
+		var keys sort.StringSlice
+
+		vals := u.Query()
+		for k := range vals {
+			keys = append(keys, k)
+		}
+
+		sort.Sort(keys)
+		canoResource += "?"
+		for i, k := range keys {
+			if i > 0 {
+				canoResource += "&"
+			}
+
+			// NOTE(review): multiple values for the same key are
+			// concatenated with no separator between them; this
+			// assumes query keys are unique — confirm against the
+			// SLS signing spec.
+			for _, v := range vals[k] {
+				canoResource += k + "=" + v
+			}
+		}
+	}
+
+	signStr := method + "\n" +
+		contentMD5 + "\n" +
+		contentType + "\n" +
+		date + "\n" +
+		canoHeaders + "\n" +
+		canoResource
+
+	// Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret))
+	mac := hmac.New(sha1.New, []byte(project.AccessKeySecret))
+	_, err = mac.Write([]byte(signStr))
+	if err != nil {
+		return
+	}
+	digest = base64.StdEncoding.EncodeToString(mac.Sum(nil))
+	return
+}
diff --git a/pkg/logs/conn.go b/pkg/logs/conn.go
new file mode 100644
index 00000000..74c458ab
--- /dev/null
+++ b/pkg/logs/conn.go
@@ -0,0 +1,119 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "encoding/json"
+ "io"
+ "net"
+ "time"
+)
+
+// connWriter implements LoggerInterface.
+// It writes messages over a keep-alive network connection dialed with
+// net.Dial(Net, Addr).
+type connWriter struct {
+	lg             *logWriter
+	innerWriter    io.WriteCloser
+	ReconnectOnMsg bool   `json:"reconnectOnMsg"` // dial and close around every single message
+	Reconnect      bool   `json:"reconnect"`      // force a fresh dial before every message
+	Net            string `json:"net"`            // network for net.Dial, e.g. "tcp"
+	Addr           string `json:"addr"`           // address for net.Dial
+	Level          int    `json:"level"`          // messages above this level are dropped
+}
+
+// NewConn create new ConnWrite returning as LoggerInterface.
+// The connection itself is dialed lazily on the first write.
+func NewConn() Logger {
+	conn := new(connWriter)
+	conn.Level = LevelTrace
+	return conn
+}
+
+// Init init connection writer with json config.
+// json config only need key "level".
+func (c *connWriter) Init(jsonConfig string) error {
+	return json.Unmarshal([]byte(jsonConfig), c)
+}
+
+// WriteMsg writes a message to the connection, (re)dialing first when
+// required. Messages above the configured level are dropped. When
+// ReconnectOnMsg is set, the connection is closed again after the write.
+func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error {
+	if level > c.Level {
+		return nil
+	}
+	if c.needToConnectOnMsg() {
+		if err := c.connect(); err != nil {
+			return err
+		}
+	}
+
+	if c.ReconnectOnMsg {
+		defer c.innerWriter.Close()
+	}
+
+	_, err := c.lg.writeln(when, msg)
+	return err
+}
+
+// Flush is a no-op: writes go straight to the connection.
+func (c *connWriter) Flush() {
+}
+
+// Destroy closes the underlying connection, if one is open.
+func (c *connWriter) Destroy() {
+	if c.innerWriter != nil {
+		c.innerWriter.Close()
+	}
+}
+
+// connect closes any previous connection, dials a fresh one and installs it
+// as the write target. TCP connections get keep-alive enabled.
+func (c *connWriter) connect() error {
+	if c.innerWriter != nil {
+		c.innerWriter.Close()
+		c.innerWriter = nil
+	}
+
+	conn, err := net.Dial(c.Net, c.Addr)
+	if err != nil {
+		return err
+	}
+
+	if tcpConn, ok := conn.(*net.TCPConn); ok {
+		tcpConn.SetKeepAlive(true)
+	}
+
+	c.innerWriter = conn
+	c.lg = newLogWriter(conn)
+	return nil
+}
+
+// needToConnectOnMsg reports whether a (re)connect must happen before the
+// next message can be written.
+func (c *connWriter) needToConnectOnMsg() bool {
+	return c.Reconnect || c.innerWriter == nil || c.ReconnectOnMsg
+}
+
+func init() {
+	Register(AdapterConn, NewConn)
+}
diff --git a/pkg/logs/conn_test.go b/pkg/logs/conn_test.go
new file mode 100644
index 00000000..bb377d41
--- /dev/null
+++ b/pkg/logs/conn_test.go
@@ -0,0 +1,79 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "net"
+ "os"
+ "testing"
+)
+
+// ConnTCPListener takes a TCP listener and accepts n TCP connections
+// Returns connections using connChan
+func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) {
+
+	// Listen and accept n incoming connections
+	for i := 0; i < n; i++ {
+		conn, err := ln.Accept()
+		if err != nil {
+			// NOTE(review): os.Exit in a helper goroutine aborts the
+			// whole test binary without cleanup; t.Error + return
+			// would be gentler — confirm intent.
+			t.Log("Error accepting connection: ", err.Error())
+			os.Exit(1)
+		}
+
+		// Send accepted connection to channel
+		connChan <- conn
+	}
+	ln.Close()
+	close(connChan)
+}
+
+// TestConn smoke-tests the conn adapter; no listener is set up, so the
+// write error path is exercised silently.
+func TestConn(t *testing.T) {
+	log := NewLogger(1000)
+	log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`)
+	log.Informational("informational")
+}
+
+// TestReconnect verifies that with "reconnect":true the adapter dials a new
+// connection for a message sent after the previous connection was closed.
+func TestReconnect(t *testing.T) {
+	// Setup connection listener
+	newConns := make(chan net.Conn)
+	connNum := 2
+	ln, err := net.Listen("tcp", ":6002")
+	if err != nil {
+		t.Log("Error listening:", err.Error())
+		os.Exit(1)
+	}
+	go connTCPListener(t, connNum, ln, newConns)
+
+	// Setup logger
+	log := NewLogger(1000)
+	log.SetPrefix("test")
+	log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`)
+	log.Informational("informational 1")
+
+	// Refuse first connection
+	first := <-newConns
+	first.Close()
+
+	// Send another log after conn closed
+	log.Informational("informational 2")
+
+	// Check if there was a second connection attempt
+	select {
+	case second := <-newConns:
+		second.Close()
+	default:
+		t.Error("Did not reconnect")
+	}
+}
diff --git a/pkg/logs/console.go b/pkg/logs/console.go
new file mode 100644
index 00000000..3dcaee1d
--- /dev/null
+++ b/pkg/logs/console.go
@@ -0,0 +1,99 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "encoding/json"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/shiena/ansicolor"
+)
+
+// brush is a color join function
+type brush func(string) string
+
+// newBrush return a fix color Brush that wraps text in the given ANSI SGR
+// color code and resets afterwards.
+func newBrush(color string) brush {
+	pre := "\033["
+	reset := "\033[0m"
+	return func(text string) string {
+		return pre + color + "m" + text + reset
+	}
+}
+
+// colors is indexed by log level (LevelEmergency..LevelDebug).
+var colors = []brush{
+	newBrush("1;37"), // Emergency          white
+	newBrush("1;36"), // Alert              cyan
+	newBrush("1;35"), // Critical           magenta
+	newBrush("1;31"), // Error              red
+	newBrush("1;33"), // Warning            yellow
+	newBrush("1;32"), // Notice             green
+	newBrush("1;34"), // Informational      blue
+	newBrush("1;44"), // Debug              Background blue
+}
+
+// consoleWriter implements LoggerInterface and writes messages to terminal.
+type consoleWriter struct {
+	lg       *logWriter
+	Level    int  `json:"level"`
+	Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color
+}
+
+// NewConsole create ConsoleWriter returning as LoggerInterface.
+// Output goes to stdout, wrapped for ANSI color support on Windows.
+func NewConsole() Logger {
+	cw := &consoleWriter{
+		lg:       newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)),
+		Level:    LevelDebug,
+		Colorful: true,
+	}
+	return cw
+}
+
+// Init init console logger.
+// jsonConfig like '{"level":LevelTrace}'.
+// An empty config keeps the defaults.
+func (c *consoleWriter) Init(jsonConfig string) error {
+	if len(jsonConfig) == 0 {
+		return nil
+	}
+	return json.Unmarshal([]byte(jsonConfig), c)
+}
+
+// WriteMsg writes a message to the console, colorizing the level prefix when
+// Colorful is enabled. Messages above the configured level are dropped.
+func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error {
+	if level > c.Level {
+		return nil
+	}
+	if c.Colorful {
+		msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1)
+	}
+	// Fix: the write error was silently discarded; report it to the caller
+	// as the other adapters (e.g. connWriter) do.
+	_, err := c.lg.writeln(when, msg)
+	return err
+}
+
+// Destroy implementing method. empty.
+func (c *consoleWriter) Destroy() {
+}
+
+// Flush implementing method. empty.
+func (c *consoleWriter) Flush() {
+}
+
+func init() {
+	Register(AdapterConsole, NewConsole)
+}
diff --git a/pkg/logs/console_test.go b/pkg/logs/console_test.go
new file mode 100644
index 00000000..4bc45f57
--- /dev/null
+++ b/pkg/logs/console_test.go
@@ -0,0 +1,64 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "testing"
+ "time"
+)
+
+// Try each log level in decreasing order of priority.
+func testConsoleCalls(bl *BeeLogger) {
+	bl.Emergency("emergency")
+	bl.Alert("alert")
+	bl.Critical("critical")
+	bl.Error("error")
+	bl.Warning("warning")
+	bl.Notice("notice")
+	bl.Informational("informational")
+	bl.Debug("debug")
+}
+
+// Test console logging by visually comparing the lines being output with and
+// without a log level specification.
+// NOTE(review): this test has no assertions; output must be inspected by eye.
+func TestConsole(t *testing.T) {
+	log1 := NewLogger(10000)
+	log1.EnableFuncCallDepth(true)
+	log1.SetLogger("console", "")
+	testConsoleCalls(log1)
+
+	log2 := NewLogger(100)
+	log2.SetLogger("console", `{"level":3}`)
+	testConsoleCalls(log2)
+}
+
+// Test console without color
+func TestConsoleNoColor(t *testing.T) {
+	log := NewLogger(100)
+	log.SetLogger("console", `{"color":false}`)
+	testConsoleCalls(log)
+}
+
+// Test console async
+func TestConsoleAsync(t *testing.T) {
+	log := NewLogger(100)
+	log.SetLogger("console")
+	log.Async()
+	//log.Close()
+	testConsoleCalls(log)
+	// Busy-wait until the async channel drains so all messages are flushed.
+	for len(log.msgChan) != 0 {
+		time.Sleep(1 * time.Millisecond)
+	}
+}
diff --git a/pkg/logs/es/es.go b/pkg/logs/es/es.go
new file mode 100644
index 00000000..2b7b1710
--- /dev/null
+++ b/pkg/logs/es/es.go
@@ -0,0 +1,102 @@
+package es
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/elastic/go-elasticsearch/v6"
+ "github.com/elastic/go-elasticsearch/v6/esapi"
+
+ "github.com/astaxie/beego/logs"
+)
+
+// NewES returns a logs.Logger that ships log messages to Elasticsearch.
+func NewES() logs.Logger {
+	return &esLogger{
+		Level: logs.LevelDebug,
+	}
+}
+
+// esLogger writes log messages into Elasticsearch. To enable it, import this
+// package for its registration side effect, usually anonymously from your
+// main package:
+//   import _ "github.com/astaxie/beego/logs/es"
+type esLogger struct {
+	*elasticsearch.Client
+	DSN   string `json:"dsn"`
+	Level int    `json:"level"`
+}
+
+// Init configures the logger from a JSON config string such as
+// {"dsn":"http://localhost:9200/","level":1} and connects the ES client.
+func (el *esLogger) Init(jsonconfig string) error {
+	if err := json.Unmarshal([]byte(jsonconfig), el); err != nil {
+		return err
+	}
+	if el.DSN == "" {
+		return errors.New("empty dsn")
+	}
+	u, err := url.Parse(el.DSN)
+	if err != nil {
+		return err
+	}
+	if u.Path == "" {
+		return errors.New("missing prefix")
+	}
+	conn, err := elasticsearch.NewClient(elasticsearch.Config{
+		Addresses: []string{el.DSN},
+	})
+	if err != nil {
+		return err
+	}
+	el.Client = conn
+	return nil
+}
+
+// WriteMsg will write the msg and level into es.
+// Documents go to a daily index named "YYYY.MM.DD" with document type "logs".
+func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error {
+	if level > el.Level {
+		return nil
+	}
+
+	idx := LogDocument{
+		Timestamp: when.Format(time.RFC3339),
+		Msg:       msg,
+	}
+
+	body, err := json.Marshal(idx)
+	if err != nil {
+		return err
+	}
+	req := esapi.IndexRequest{
+		Index:        fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()),
+		DocumentType: "logs",
+		Body:         strings.NewReader(string(body)),
+	}
+	// NOTE(review): the esapi.Response is discarded, so HTTP-level indexing
+	// failures (non-2xx) are not surfaced — confirm this is acceptable.
+	_, err = req.Do(context.Background(), el.Client)
+	return err
+}
+
+// Destroy is a empty method
+func (el *esLogger) Destroy() {
+}
+
+// Flush is a empty method
+func (el *esLogger) Flush() {
+
+}
+
+// LogDocument is the JSON document body indexed into Elasticsearch for a
+// single log message.
+type LogDocument struct {
+	Timestamp string `json:"timestamp"`
+	Msg       string `json:"msg"`
+}
+
+func init() {
+	logs.Register(logs.AdapterEs, NewES)
+}
diff --git a/pkg/logs/file.go b/pkg/logs/file.go
new file mode 100644
index 00000000..222db989
--- /dev/null
+++ b/pkg/logs/file.go
@@ -0,0 +1,409 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+// fileLogWriter implements LoggerInterface.
+// It writes messages by lines limit, file size limit, or time frequency.
+type fileLogWriter struct {
+	sync.RWMutex // write log order by order and  atomic incr maxLinesCurLines and maxSizeCurSize
+	// The opened file
+	Filename   string `json:"filename"`
+	fileWriter *os.File
+
+	// Rotate at line
+	MaxLines         int `json:"maxlines"`
+	maxLinesCurLines int
+
+	// MaxFiles caps the numeric suffix searched for when rotating.
+	MaxFiles         int `json:"maxfiles"`
+	MaxFilesCurFiles int
+
+	// Rotate at size
+	MaxSize        int `json:"maxsize"`
+	maxSizeCurSize int
+
+	// Rotate daily
+	Daily         bool  `json:"daily"`
+	MaxDays       int64 `json:"maxdays"` // rotated files older than this many days are deleted
+	dailyOpenDate int
+	dailyOpenTime time.Time
+
+	// Rotate hourly
+	Hourly         bool  `json:"hourly"`
+	MaxHours       int64 `json:"maxhours"` // rotated files older than this many hours are deleted
+	hourlyOpenDate int
+	hourlyOpenTime time.Time
+
+	Rotate bool `json:"rotate"`
+
+	Level int `json:"level"`
+
+	// Perm is the octal permission string for the live log file.
+	Perm string `json:"perm"`
+
+	// RotatePerm is the octal permission string applied to rotated files.
+	RotatePerm string `json:"rotateperm"`
+
+	fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix
+}
+
+// newFileWriter create a FileLogWriter returning as LoggerInterface.
+// Defaults: rotate daily, keep 7 days / 168 hours, 10M lines or 256 MB per
+// file, live-file perm 0660, rotated-file perm 0440.
+func newFileWriter() Logger {
+	w := &fileLogWriter{
+		Daily:      true,
+		MaxDays:    7,
+		Hourly:     false,
+		MaxHours:   168,
+		Rotate:     true,
+		RotatePerm: "0440",
+		Level:      LevelTrace,
+		Perm:       "0660",
+		MaxLines:   10000000,
+		MaxFiles:   999,
+		MaxSize:    1 << 28,
+	}
+	return w
+}
+
+// Init file logger with json config.
+// jsonConfig like:
+//	{
+//	"filename":"logs/beego.log",
+//	"maxLines":10000,
+//	"maxsize":1024,
+//	"daily":true,
+//	"maxDays":15,
+//	"rotate":true,
+//	"perm":"0600"
+//	}
+// A filename is mandatory; a missing extension defaults to ".log".
+func (w *fileLogWriter) Init(jsonConfig string) error {
+	err := json.Unmarshal([]byte(jsonConfig), w)
+	if err != nil {
+		return err
+	}
+	if len(w.Filename) == 0 {
+		return errors.New("jsonconfig must have filename")
+	}
+	w.suffix = filepath.Ext(w.Filename)
+	w.fileNameOnly = strings.TrimSuffix(w.Filename, w.suffix)
+	if w.suffix == "" {
+		w.suffix = ".log"
+	}
+	err = w.startLogger()
+	return err
+}
+
+// start file logger. create log file and set to locker-inside file writer.
+// Any previously open file is closed first.
+func (w *fileLogWriter) startLogger() error {
+	file, err := w.createLogFile()
+	if err != nil {
+		return err
+	}
+	if w.fileWriter != nil {
+		w.fileWriter.Close()
+	}
+	w.fileWriter = file
+	return w.initFd()
+}
+
+// needRotateDaily reports whether the file must rotate: line or size limit
+// reached, or the calendar day changed since the file was opened.
+// NOTE(review): the size parameter is currently unused.
+func (w *fileLogWriter) needRotateDaily(size int, day int) bool {
+	return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
+		(w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
+		(w.Daily && day != w.dailyOpenDate)
+}
+
+// needRotateHourly is the hourly counterpart of needRotateDaily.
+// NOTE(review): the size parameter is currently unused.
+func (w *fileLogWriter) needRotateHourly(size int, hour int) bool {
+	return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) ||
+		(w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) ||
+		(w.Hourly && hour != w.hourlyOpenDate)
+
+}
+
+// WriteMsg write logger message into file.
+// Rotation uses a double-checked pattern: the cheap check runs under the
+// read lock, and is re-run under the write lock before actually rotating so
+// concurrent writers rotate at most once.
+func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
+	if level > w.Level {
+		return nil
+	}
+	hd, d, h := formatTimeHeader(when)
+	msg = string(hd) + msg + "\n"
+	if w.Rotate {
+		w.RLock()
+		if w.needRotateHourly(len(msg), h) {
+			w.RUnlock()
+			w.Lock()
+			// Re-check: another writer may have rotated already.
+			if w.needRotateHourly(len(msg), h) {
+				if err := w.doRotate(when); err != nil {
+					fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
+				}
+			}
+			w.Unlock()
+		} else if w.needRotateDaily(len(msg), d) {
+			w.RUnlock()
+			w.Lock()
+			if w.needRotateDaily(len(msg), d) {
+				if err := w.doRotate(when); err != nil {
+					fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
+				}
+			}
+			w.Unlock()
+		} else {
+			w.RUnlock()
+		}
+	}
+
+	// Serialize the write and the counter updates under the write lock.
+	w.Lock()
+	_, err := w.fileWriter.Write([]byte(msg))
+	if err == nil {
+		w.maxLinesCurLines++
+		w.maxSizeCurSize += len(msg)
+	}
+	w.Unlock()
+	return err
+}
+
+// createLogFile opens (creating if needed) the log file for appending, with
+// the configured octal permission, creating parent directories as required.
+func (w *fileLogWriter) createLogFile() (*os.File, error) {
+	// Open the log file
+	perm, err := strconv.ParseInt(w.Perm, 8, 64)
+	if err != nil {
+		return nil, err
+	}
+
+	// NOTE(review): the MkdirAll error is ignored; a failure surfaces as
+	// the OpenFile error below.
+	filepath := path.Dir(w.Filename)
+	os.MkdirAll(filepath, os.FileMode(perm))
+
+	fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm))
+	if err == nil {
+		// Make sure file perm is user set perm cause of `os.OpenFile` will obey umask
+		os.Chmod(w.Filename, os.FileMode(perm))
+	}
+	return fd, err
+}
+
+// initFd records the freshly opened file's size, the open day/hour used by
+// the rotation checks, and starts the background rotation timer. When a line
+// limit is configured, the existing line count is re-read from the file.
+func (w *fileLogWriter) initFd() error {
+	fd := w.fileWriter
+	fInfo, err := fd.Stat()
+	if err != nil {
+		return fmt.Errorf("get stat err: %s", err)
+	}
+	w.maxSizeCurSize = int(fInfo.Size())
+	w.dailyOpenTime = time.Now()
+	w.dailyOpenDate = w.dailyOpenTime.Day()
+	w.hourlyOpenTime = time.Now()
+	w.hourlyOpenDate = w.hourlyOpenTime.Hour()
+	w.maxLinesCurLines = 0
+	if w.Hourly {
+		go w.hourlyRotate(w.hourlyOpenTime)
+	} else if w.Daily {
+		go w.dailyRotate(w.dailyOpenTime)
+	}
+	if fInfo.Size() > 0 && w.MaxLines > 0 {
+		count, err := w.lines()
+		if err != nil {
+			return err
+		}
+		w.maxLinesCurLines = count
+	}
+	return nil
+}
+
+// dailyRotate sleeps until just past the next midnight (100ns slack) and
+// rotates if the day has in fact changed.
+func (w *fileLogWriter) dailyRotate(openTime time.Time) {
+	y, m, d := openTime.Add(24 * time.Hour).Date()
+	nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location())
+	tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100))
+	<-tm.C
+	w.Lock()
+	if w.needRotateDaily(0, time.Now().Day()) {
+		if err := w.doRotate(time.Now()); err != nil {
+			fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
+		}
+	}
+	w.Unlock()
+}
+
+// hourlyRotate sleeps until just past the top of the next hour (100ns slack)
+// and rotates if the hour has in fact changed.
+func (w *fileLogWriter) hourlyRotate(openTime time.Time) {
+	y, m, d := openTime.Add(1 * time.Hour).Date()
+	h, _, _ := openTime.Add(1 * time.Hour).Clock()
+	nextHour := time.Date(y, m, d, h, 0, 0, 0, openTime.Location())
+	tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100))
+	<-tm.C
+	w.Lock()
+	if w.needRotateHourly(0, time.Now().Hour()) {
+		if err := w.doRotate(time.Now()); err != nil {
+			fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
+		}
+	}
+	w.Unlock()
+}
+
+// lines counts the newline characters currently in the log file, reading it
+// in fixed 32 KiB chunks.
+func (w *fileLogWriter) lines() (int, error) {
+	fd, err := os.Open(w.Filename)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.Close()
+
+	chunk := make([]byte, 32768) // 32k
+	newline := []byte{'\n'}
+	total := 0
+
+	for {
+		n, readErr := fd.Read(chunk)
+		if readErr != nil && readErr != io.EOF {
+			return total, readErr
+		}
+
+		total += bytes.Count(chunk[:n], newline)
+
+		if readErr == io.EOF {
+			return total, nil
+		}
+	}
+}
+
+// DoRotate means it need to write file in new file.
+// new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size)
+// Caller must hold the write lock. Whatever happens (even a rename failure),
+// control jumps to RESTART_LOGGER so a fresh live file is always reopened.
+func (w *fileLogWriter) doRotate(logTime time.Time) error {
+	// file exists
+	// Find the next available number
+	num := w.MaxFilesCurFiles + 1
+	fName := ""
+	format := ""
+	var openTime time.Time
+	rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64)
+	if err != nil {
+		return err
+	}
+
+	_, err = os.Lstat(w.Filename)
+	if err != nil {
+		//even if the file is not exist or other ,we should RESTART the logger
+		goto RESTART_LOGGER
+	}
+
+	// Hourly takes precedence over Daily for the timestamp in the name.
+	if w.Hourly {
+		format = "2006010215"
+		openTime = w.hourlyOpenTime
+	} else if w.Daily {
+		format = "2006-01-02"
+		openTime = w.dailyOpenTime
+	}
+
+	// only when one of them be setted, then the file would be splited
+	if w.MaxLines > 0 || w.MaxSize > 0 {
+		// Probe numbered names until one does not exist yet (Lstat errors).
+		for ; err == nil && num <= w.MaxFiles; num++ {
+			fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", logTime.Format(format), num, w.suffix)
+			_, err = os.Lstat(fName)
+		}
+	} else {
+		fName = w.fileNameOnly + fmt.Sprintf(".%s.%03d%s", openTime.Format(format), num, w.suffix)
+		_, err = os.Lstat(fName)
+		w.MaxFilesCurFiles = num
+	}
+
+	// return error if the last file checked still existed
+	if err == nil {
+		return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename)
+	}
+
+	// close fileWriter before rename
+	w.fileWriter.Close()
+
+	// Rename the file to its new found name
+	// even if occurs error,we MUST guarantee to restart new logger
+	err = os.Rename(w.Filename, fName)
+	if err != nil {
+		goto RESTART_LOGGER
+	}
+
+	err = os.Chmod(fName, os.FileMode(rotatePerm))
+
+RESTART_LOGGER:
+
+	startLoggerErr := w.startLogger()
+	// Prune expired rotated files in the background.
+	go w.deleteOldLog()
+
+	if startLoggerErr != nil {
+		return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr)
+	}
+	if err != nil {
+		return fmt.Errorf("Rotate: %s", err)
+	}
+	return nil
+}
+
+// deleteOldLog removes rotated log files in the log directory whose mtime is
+// older than MaxHours (hourly mode) or MaxDays (daily mode). Files must share
+// the live file's base-name prefix and suffix to be considered.
+func (w *fileLogWriter) deleteOldLog() {
+	dir := filepath.Dir(w.Filename)
+	// Follow a symlinked log file so we prune the real directory.
+	absolutePath, err := filepath.EvalSymlinks(w.Filename)
+	if err == nil {
+		dir = filepath.Dir(absolutePath)
+	}
+	filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) {
+		defer func() {
+			// A panic here must not kill the background goroutine.
+			if r := recover(); r != nil {
+				fmt.Fprintf(os.Stderr, "Unable to delete old log '%s', error: %v\n", path, r)
+			}
+		}()
+
+		if info == nil {
+			return
+		}
+		if w.Hourly {
+			if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) {
+				if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
+					strings.HasSuffix(filepath.Base(path), w.suffix) {
+					os.Remove(path)
+				}
+			}
+		} else if w.Daily {
+			if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) {
+				if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) &&
+					strings.HasSuffix(filepath.Base(path), w.suffix) {
+					os.Remove(path)
+				}
+			}
+		}
+		return
+	})
+}
+
+// Destroy close the file description, close file writer.
+func (w *fileLogWriter) Destroy() {
+	w.fileWriter.Close()
+}
+
+// Flush flush file logger.
+// there are no buffering messages in file logger in memory.
+// flush file means sync file from disk.
+func (w *fileLogWriter) Flush() {
+	w.fileWriter.Sync()
+}
+
+func init() {
+	Register(AdapterFile, newFileWriter)
+}
diff --git a/pkg/logs/file_test.go b/pkg/logs/file_test.go
new file mode 100644
index 00000000..e7c2ca9a
--- /dev/null
+++ b/pkg/logs/file_test.go
@@ -0,0 +1,420 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "bufio"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strconv"
+ "testing"
+ "time"
+)
+
+func TestFilePerm(t *testing.T) {
+ log := NewLogger(10000)
+ // use 0666 as test perm cause the default umask is 022
+ log.SetLogger("file", `{"filename":"test.log", "perm": "0666"}`)
+ log.Debug("debug")
+ log.Informational("info")
+ log.Notice("notice")
+ log.Warning("warning")
+ log.Error("error")
+ log.Alert("alert")
+ log.Critical("critical")
+ log.Emergency("emergency")
+ file, err := os.Stat("test.log")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if file.Mode() != 0666 {
+ t.Fatal("unexpected log file permission")
+ }
+ os.Remove("test.log")
+}
+
+func TestFile1(t *testing.T) {
+ log := NewLogger(10000)
+ log.SetLogger("file", `{"filename":"test.log"}`)
+ log.Debug("debug")
+ log.Informational("info")
+ log.Notice("notice")
+ log.Warning("warning")
+ log.Error("error")
+ log.Alert("alert")
+ log.Critical("critical")
+ log.Emergency("emergency")
+ f, err := os.Open("test.log")
+ if err != nil {
+ t.Fatal(err)
+ }
+ b := bufio.NewReader(f)
+ lineNum := 0
+ for {
+ line, _, err := b.ReadLine()
+ if err != nil {
+ break
+ }
+ if len(line) > 0 {
+ lineNum++
+ }
+ }
+ var expected = LevelDebug + 1
+ if lineNum != expected {
+ t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
+ }
+ os.Remove("test.log")
+}
+
+func TestFile2(t *testing.T) {
+ log := NewLogger(10000)
+ log.SetLogger("file", fmt.Sprintf(`{"filename":"test2.log","level":%d}`, LevelError))
+ log.Debug("debug")
+ log.Info("info")
+ log.Notice("notice")
+ log.Warning("warning")
+ log.Error("error")
+ log.Alert("alert")
+ log.Critical("critical")
+ log.Emergency("emergency")
+ f, err := os.Open("test2.log")
+ if err != nil {
+ t.Fatal(err)
+ }
+ b := bufio.NewReader(f)
+ lineNum := 0
+ for {
+ line, _, err := b.ReadLine()
+ if err != nil {
+ break
+ }
+ if len(line) > 0 {
+ lineNum++
+ }
+ }
+ var expected = LevelError + 1
+ if lineNum != expected {
+ t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines")
+ }
+ os.Remove("test2.log")
+}
+
+func TestFileDailyRotate_01(t *testing.T) {
+ log := NewLogger(10000)
+ log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
+ log.Debug("debug")
+ log.Info("info")
+ log.Notice("notice")
+ log.Warning("warning")
+ log.Error("error")
+ log.Alert("alert")
+ log.Critical("critical")
+ log.Emergency("emergency")
+ rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
+ b, err := exists(rotateName)
+ if !b || err != nil {
+ os.Remove("test3.log")
+ t.Fatal("rotate not generated")
+ }
+ os.Remove(rotateName)
+ os.Remove("test3.log")
+}
+
+func TestFileDailyRotate_02(t *testing.T) {
+ fn1 := "rotate_day.log"
+ fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
+ testFileRotate(t, fn1, fn2, true, false)
+}
+
+func TestFileDailyRotate_03(t *testing.T) {
+ fn1 := "rotate_day.log"
+ fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
+ os.Create(fn)
+ fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
+ testFileRotate(t, fn1, fn2, true, false)
+ os.Remove(fn)
+}
+
+func TestFileDailyRotate_04(t *testing.T) {
+ fn1 := "rotate_day.log"
+ fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
+ testFileDailyRotate(t, fn1, fn2)
+}
+
+func TestFileDailyRotate_05(t *testing.T) {
+ fn1 := "rotate_day.log"
+ fn := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".log"
+ os.Create(fn)
+ fn2 := "rotate_day." + time.Now().Add(-24*time.Hour).Format("2006-01-02") + ".001.log"
+ testFileDailyRotate(t, fn1, fn2)
+ os.Remove(fn)
+}
+func TestFileDailyRotate_06(t *testing.T) { //test file mode
+ log := NewLogger(10000)
+ log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
+ log.Debug("debug")
+ log.Info("info")
+ log.Notice("notice")
+ log.Warning("warning")
+ log.Error("error")
+ log.Alert("alert")
+ log.Critical("critical")
+ log.Emergency("emergency")
+ rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
+ s, _ := os.Lstat(rotateName)
+ if s.Mode() != 0440 {
+ os.Remove(rotateName)
+ os.Remove("test3.log")
+ t.Fatal("rotate file mode error")
+ }
+ os.Remove(rotateName)
+ os.Remove("test3.log")
+}
+
+func TestFileHourlyRotate_01(t *testing.T) {
+ log := NewLogger(10000)
+ log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`)
+ log.Debug("debug")
+ log.Info("info")
+ log.Notice("notice")
+ log.Warning("warning")
+ log.Error("error")
+ log.Alert("alert")
+ log.Critical("critical")
+ log.Emergency("emergency")
+ rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log"
+ b, err := exists(rotateName)
+ if !b || err != nil {
+ os.Remove("test3.log")
+ t.Fatal("rotate not generated")
+ }
+ os.Remove(rotateName)
+ os.Remove("test3.log")
+}
+
+func TestFileHourlyRotate_02(t *testing.T) {
+ fn1 := "rotate_hour.log"
+ fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
+ testFileRotate(t, fn1, fn2, false, true)
+}
+
+func TestFileHourlyRotate_03(t *testing.T) {
+ fn1 := "rotate_hour.log"
+ fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log"
+ os.Create(fn)
+ fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
+ testFileRotate(t, fn1, fn2, false, true)
+ os.Remove(fn)
+}
+
+func TestFileHourlyRotate_04(t *testing.T) {
+ fn1 := "rotate_hour.log"
+ fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
+ testFileHourlyRotate(t, fn1, fn2)
+}
+
+func TestFileHourlyRotate_05(t *testing.T) {
+ fn1 := "rotate_hour.log"
+ fn := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".log"
+ os.Create(fn)
+ fn2 := "rotate_hour." + time.Now().Add(-1*time.Hour).Format("2006010215") + ".001.log"
+ testFileHourlyRotate(t, fn1, fn2)
+ os.Remove(fn)
+}
+
+func TestFileHourlyRotate_06(t *testing.T) { //test file mode
+ log := NewLogger(10000)
+ log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`)
+ log.Debug("debug")
+ log.Info("info")
+ log.Notice("notice")
+ log.Warning("warning")
+ log.Error("error")
+ log.Alert("alert")
+ log.Critical("critical")
+ log.Emergency("emergency")
+ rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log"
+ s, _ := os.Lstat(rotateName)
+ if s.Mode() != 0440 {
+ os.Remove(rotateName)
+ os.Remove("test3.log")
+ t.Fatal("rotate file mode error")
+ }
+ os.Remove(rotateName)
+ os.Remove("test3.log")
+}
+
+func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) {
+	fw := &fileLogWriter{
+		Daily:      daily,
+		MaxDays:    7,
+		Hourly:     hourly,
+		MaxHours:   168,
+		Rotate:     true,
+		Level:      LevelTrace,
+		Perm:       "0660",
+		RotatePerm: "0440",
+	}
+
+	if daily {
+		fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
+		fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
+		fw.dailyOpenDate = fw.dailyOpenTime.Day()
+	}
+
+	if hourly {
+		fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1))
+		fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour)
+		fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() // Hour(), not Day(): field tracks the hour of last rotation (cf. testFileHourlyRotate)
+	}
+
+	fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug)
+
+	for _, file := range []string{fn1, fn2} {
+		_, err := os.Stat(file)
+		if err != nil {
+			t.Log(err)
+			t.FailNow()
+		}
+		os.Remove(file)
+	}
+	fw.Destroy()
+}
+
+func testFileDailyRotate(t *testing.T, fn1, fn2 string) {
+ fw := &fileLogWriter{
+ Daily: true,
+ MaxDays: 7,
+ Rotate: true,
+ Level: LevelTrace,
+ Perm: "0660",
+ RotatePerm: "0440",
+ }
+ fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
+ fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
+ fw.dailyOpenDate = fw.dailyOpenTime.Day()
+ today, _ := time.ParseInLocation("2006-01-02", time.Now().Format("2006-01-02"), fw.dailyOpenTime.Location())
+ today = today.Add(-1 * time.Second)
+ fw.dailyRotate(today)
+ for _, file := range []string{fn1, fn2} {
+ _, err := os.Stat(file)
+ if err != nil {
+ t.FailNow()
+ }
+ content, err := ioutil.ReadFile(file)
+ if err != nil {
+ t.FailNow()
+ }
+ if len(content) > 0 {
+ t.FailNow()
+ }
+ os.Remove(file)
+ }
+ fw.Destroy()
+}
+
+func testFileHourlyRotate(t *testing.T, fn1, fn2 string) {
+ fw := &fileLogWriter{
+ Hourly: true,
+ MaxHours: 168,
+ Rotate: true,
+ Level: LevelTrace,
+ Perm: "0660",
+ RotatePerm: "0440",
+ }
+ fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1))
+ fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour)
+ fw.hourlyOpenDate = fw.hourlyOpenTime.Hour()
+ hour, _ := time.ParseInLocation("2006010215", time.Now().Format("2006010215"), fw.hourlyOpenTime.Location())
+ hour = hour.Add(-1 * time.Second)
+ fw.hourlyRotate(hour)
+ for _, file := range []string{fn1, fn2} {
+ _, err := os.Stat(file)
+ if err != nil {
+ t.FailNow()
+ }
+ content, err := ioutil.ReadFile(file)
+ if err != nil {
+ t.FailNow()
+ }
+ if len(content) > 0 {
+ t.FailNow()
+ }
+ os.Remove(file)
+ }
+ fw.Destroy()
+}
+func exists(path string) (bool, error) {
+ _, err := os.Stat(path)
+ if err == nil {
+ return true, nil
+ }
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return false, err
+}
+
+func BenchmarkFile(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ for i := 0; i < b.N; i++ {
+ log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
+
+func BenchmarkFileAsynchronous(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ log.Async()
+ for i := 0; i < b.N; i++ {
+ log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
+
+func BenchmarkFileCallDepth(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ log.EnableFuncCallDepth(true)
+ log.SetLogFuncCallDepth(2)
+ for i := 0; i < b.N; i++ {
+ log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
+
+func BenchmarkFileAsynchronousCallDepth(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ log.EnableFuncCallDepth(true)
+ log.SetLogFuncCallDepth(2)
+ log.Async()
+ for i := 0; i < b.N; i++ {
+ log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
+
+func BenchmarkFileOnGoroutine(b *testing.B) {
+ log := NewLogger(100000)
+ log.SetLogger("file", `{"filename":"test4.log"}`)
+ for i := 0; i < b.N; i++ {
+ go log.Debug("debug")
+ }
+ os.Remove("test4.log")
+}
diff --git a/pkg/logs/jianliao.go b/pkg/logs/jianliao.go
new file mode 100644
index 00000000..88ba0f9a
--- /dev/null
+++ b/pkg/logs/jianliao.go
@@ -0,0 +1,72 @@
+package logs
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "time"
+)
+
+// JLWriter implements beego LoggerInterface and sends log messages to a Jianliao webhook
+type JLWriter struct {
+	AuthorName  string `json:"authorname"`
+	Title       string `json:"title"`
+	WebhookURL  string `json:"webhookurl"`
+	RedirectURL string `json:"redirecturl,omitempty"`
+	ImageURL    string `json:"imageurl,omitempty"`
+	Level       int    `json:"level"`
+}
+
+// newJLWriter creates a Jianliao webhook writer with the default trace level.
+func newJLWriter() Logger {
+	return &JLWriter{Level: LevelTrace}
+}
+
+// Init JLWriter with json config string
+func (s *JLWriter) Init(jsonconfig string) error {
+ return json.Unmarshal([]byte(jsonconfig), s)
+}
+
+// WriteMsg posts the timestamped log message as form data to the configured
+// Jianliao webhook URL; messages above s.Level are silently dropped.
+func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error {
+	if level > s.Level {
+		return nil
+	}
+
+	text := fmt.Sprintf("%s %s", when.Format("2006-01-02 15:04:05"), msg)
+
+	form := url.Values{}
+	form.Add("authorName", s.AuthorName)
+	form.Add("title", s.Title)
+	form.Add("text", text)
+	if s.RedirectURL != "" {
+		form.Add("redirectUrl", s.RedirectURL)
+	}
+	if s.ImageURL != "" {
+		form.Add("imageUrl", s.ImageURL)
+	}
+
+	resp, err := http.PostForm(s.WebhookURL, form)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode)
+	}
+	return nil
+}
+
+// Flush implementing method. empty.
+func (s *JLWriter) Flush() {
+}
+
+// Destroy implementing method. empty.
+func (s *JLWriter) Destroy() {
+}
+
+func init() {
+ Register(AdapterJianLiao, newJLWriter)
+}
diff --git a/pkg/logs/log.go b/pkg/logs/log.go
new file mode 100644
index 00000000..39c006d2
--- /dev/null
+++ b/pkg/logs/log.go
@@ -0,0 +1,669 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package logs provide a general log interface
+// Usage:
+//
+// import "github.com/astaxie/beego/logs"
+//
+// log := NewLogger(10000)
+// log.SetLogger("console", "")
+//
+// > the first parameter stands for the size of the buffered message channel
+//
+// Use it like this:
+//
+// log.Trace("trace")
+// log.Info("info")
+// log.Warn("warning")
+// log.Debug("debug")
+// log.Critical("critical")
+//
+// more docs http://beego.me/docs/module/logs.md
+package logs
+
+import (
+ "fmt"
+ "log"
+ "os"
+ "path"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+// RFC5424 log message levels.
+const (
+ LevelEmergency = iota
+ LevelAlert
+ LevelCritical
+ LevelError
+ LevelWarning
+ LevelNotice
+ LevelInformational
+ LevelDebug
+)
+
+// levelLoggerImpl is the pseudo level used when BeeLogger acts as an io.Writer
+// for log.Logger; such messages are written out at LevelEmergency
+const levelLoggerImpl = -1
+
+// Name for adapter with beego official support
+const (
+ AdapterConsole = "console"
+ AdapterFile = "file"
+ AdapterMultiFile = "multifile"
+ AdapterMail = "smtp"
+ AdapterConn = "conn"
+ AdapterEs = "es"
+ AdapterJianLiao = "jianliao"
+ AdapterSlack = "slack"
+ AdapterAliLS = "alils"
+)
+
+// Legacy log level constants to ensure backwards compatibility.
+const (
+ LevelInfo = LevelInformational
+ LevelTrace = LevelDebug
+ LevelWarn = LevelWarning
+)
+
+type newLoggerFunc func() Logger
+
+// Logger defines the behavior of a log provider.
+type Logger interface {
+ Init(config string) error
+ WriteMsg(when time.Time, msg string, level int) error
+ Destroy()
+ Flush()
+}
+
+var adapters = make(map[string]newLoggerFunc)
+var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"}
+
+// Register makes a log provide available by the provided name.
+// If Register is called twice with the same name or if driver is nil,
+// it panics.
+func Register(name string, log newLoggerFunc) {
+ if log == nil {
+ panic("logs: Register provide is nil")
+ }
+ if _, dup := adapters[name]; dup {
+ panic("logs: Register called twice for provider " + name)
+ }
+ adapters[name] = log
+}
+
+// BeeLogger is default logger in beego application.
+// it can contain several providers and log message into all providers.
+type BeeLogger struct {
+ lock sync.Mutex
+ level int
+ init bool
+ enableFuncCallDepth bool
+ loggerFuncCallDepth int
+ asynchronous bool
+ prefix string
+ msgChanLen int64
+ msgChan chan *logMsg
+ signalChan chan string
+ wg sync.WaitGroup
+ outputs []*nameLogger
+}
+
+const defaultAsyncMsgLen = 1e3
+
+type nameLogger struct {
+ Logger
+ name string
+}
+
+type logMsg struct {
+ level int
+ msg string
+ when time.Time
+}
+
+var logMsgPool *sync.Pool
+
+// NewLogger returns a new BeeLogger.
+// channelLen means the number of messages in chan(used where asynchronous is true).
+// if the buffering chan is full, logger adapters write to file or other way.
+func NewLogger(channelLens ...int64) *BeeLogger {
+ bl := new(BeeLogger)
+ bl.level = LevelDebug
+ bl.loggerFuncCallDepth = 2
+ bl.msgChanLen = append(channelLens, 0)[0]
+ if bl.msgChanLen <= 0 {
+ bl.msgChanLen = defaultAsyncMsgLen
+ }
+ bl.signalChan = make(chan string, 1)
+ bl.setLogger(AdapterConsole)
+ return bl
+}
+
+// Async set the log to asynchronous and start the goroutine
+func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger {
+ bl.lock.Lock()
+ defer bl.lock.Unlock()
+ if bl.asynchronous {
+ return bl
+ }
+ bl.asynchronous = true
+ if len(msgLen) > 0 && msgLen[0] > 0 {
+ bl.msgChanLen = msgLen[0]
+ }
+ bl.msgChan = make(chan *logMsg, bl.msgChanLen)
+ logMsgPool = &sync.Pool{
+ New: func() interface{} {
+ return &logMsg{}
+ },
+ }
+ bl.wg.Add(1)
+ go bl.startLogger()
+ return bl
+}
+
+// SetLogger provides a given logger adapter into BeeLogger with config string.
+// config need to be correct JSON as string: {"interval":360}.
+func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error {
+ config := append(configs, "{}")[0]
+ for _, l := range bl.outputs {
+ if l.name == adapterName {
+ return fmt.Errorf("logs: duplicate adaptername %q (you have set this logger before)", adapterName)
+ }
+ }
+
+ logAdapter, ok := adapters[adapterName]
+ if !ok {
+ return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
+ }
+
+ lg := logAdapter()
+ err := lg.Init(config)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error())
+ return err
+ }
+ bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg})
+ return nil
+}
+
+// SetLogger provides a given logger adapter into BeeLogger with config string.
+// config need to be correct JSON as string: {"interval":360}.
+func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error {
+ bl.lock.Lock()
+ defer bl.lock.Unlock()
+ if !bl.init {
+ bl.outputs = []*nameLogger{}
+ bl.init = true
+ }
+ return bl.setLogger(adapterName, configs...)
+}
+
+// DelLogger remove a logger adapter in BeeLogger.
+func (bl *BeeLogger) DelLogger(adapterName string) error {
+ bl.lock.Lock()
+ defer bl.lock.Unlock()
+ outputs := []*nameLogger{}
+ for _, lg := range bl.outputs {
+ if lg.name == adapterName {
+ lg.Destroy()
+ } else {
+ outputs = append(outputs, lg)
+ }
+ }
+ if len(outputs) == len(bl.outputs) {
+ return fmt.Errorf("logs: unknown adaptername %q (forgotten Register?)", adapterName)
+ }
+ bl.outputs = outputs
+ return nil
+}
+
+func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) {
+ for _, l := range bl.outputs {
+ err := l.WriteMsg(when, msg, level)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err)
+ }
+ }
+}
+
+func (bl *BeeLogger) Write(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ return 0, nil
+ }
+ // writeMsg will always add a '\n' character
+ if p[len(p)-1] == '\n' {
+ p = p[0 : len(p)-1]
+ }
+ // set levelLoggerImpl to ensure all log message will be write out
+ err = bl.writeMsg(levelLoggerImpl, string(p))
+ if err == nil {
+ return len(p), err
+ }
+ return 0, err
+}
+
+func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error {
+ if !bl.init {
+ bl.lock.Lock()
+ bl.setLogger(AdapterConsole)
+ bl.lock.Unlock()
+ }
+
+ if len(v) > 0 {
+ msg = fmt.Sprintf(msg, v...)
+ }
+
+ msg = bl.prefix + " " + msg
+
+ when := time.Now()
+ if bl.enableFuncCallDepth {
+ _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
+ if !ok {
+ file = "???"
+ line = 0
+ }
+ _, filename := path.Split(file)
+ msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg
+ }
+
+ //set level info in front of filename info
+ if logLevel == levelLoggerImpl {
+ // set to emergency to ensure all log will be print out correctly
+ logLevel = LevelEmergency
+ } else {
+ msg = levelPrefix[logLevel] + " " + msg
+ }
+
+ if bl.asynchronous {
+ lm := logMsgPool.Get().(*logMsg)
+ lm.level = logLevel
+ lm.msg = msg
+ lm.when = when
+ if bl.outputs != nil {
+ bl.msgChan <- lm
+ } else {
+ logMsgPool.Put(lm)
+ }
+ } else {
+ bl.writeToLoggers(when, msg, logLevel)
+ }
+ return nil
+}
+
+// SetLevel Set log message level.
+// If message level (such as LevelDebug) is higher than logger level (such as LevelWarning),
+// log providers will not even be sent the message.
+func (bl *BeeLogger) SetLevel(l int) {
+ bl.level = l
+}
+
+// GetLevel Get Current log message level.
+func (bl *BeeLogger) GetLevel() int {
+ return bl.level
+}
+
+// SetLogFuncCallDepth set log funcCallDepth
+func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
+ bl.loggerFuncCallDepth = d
+}
+
+// GetLogFuncCallDepth return log funcCallDepth for wrapper
+func (bl *BeeLogger) GetLogFuncCallDepth() int {
+ return bl.loggerFuncCallDepth
+}
+
+// EnableFuncCallDepth enable log funcCallDepth
+func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
+ bl.enableFuncCallDepth = b
+}
+
+// SetPrefix sets the prefix prepended to every log message
+func (bl *BeeLogger) SetPrefix(s string) {
+	bl.prefix = s
+}
+
+// start logger chan reading.
+// when chan is not empty, write logs.
+func (bl *BeeLogger) startLogger() {
+ gameOver := false
+ for {
+ select {
+ case bm := <-bl.msgChan:
+ bl.writeToLoggers(bm.when, bm.msg, bm.level)
+ logMsgPool.Put(bm)
+ case sg := <-bl.signalChan:
+ // Now should only send "flush" or "close" to bl.signalChan
+ bl.flush()
+ if sg == "close" {
+ for _, l := range bl.outputs {
+ l.Destroy()
+ }
+ bl.outputs = nil
+ gameOver = true
+ }
+ bl.wg.Done()
+ }
+ if gameOver {
+ break
+ }
+ }
+}
+
+// Emergency Log EMERGENCY level message.
+func (bl *BeeLogger) Emergency(format string, v ...interface{}) {
+ if LevelEmergency > bl.level {
+ return
+ }
+ bl.writeMsg(LevelEmergency, format, v...)
+}
+
+// Alert Log ALERT level message.
+func (bl *BeeLogger) Alert(format string, v ...interface{}) {
+ if LevelAlert > bl.level {
+ return
+ }
+ bl.writeMsg(LevelAlert, format, v...)
+}
+
+// Critical Log CRITICAL level message.
+func (bl *BeeLogger) Critical(format string, v ...interface{}) {
+ if LevelCritical > bl.level {
+ return
+ }
+ bl.writeMsg(LevelCritical, format, v...)
+}
+
+// Error Log ERROR level message.
+func (bl *BeeLogger) Error(format string, v ...interface{}) {
+ if LevelError > bl.level {
+ return
+ }
+ bl.writeMsg(LevelError, format, v...)
+}
+
+// Warning Log WARNING level message.
+func (bl *BeeLogger) Warning(format string, v ...interface{}) {
+ if LevelWarn > bl.level {
+ return
+ }
+ bl.writeMsg(LevelWarn, format, v...)
+}
+
+// Notice Log NOTICE level message.
+func (bl *BeeLogger) Notice(format string, v ...interface{}) {
+ if LevelNotice > bl.level {
+ return
+ }
+ bl.writeMsg(LevelNotice, format, v...)
+}
+
+// Informational Log INFORMATIONAL level message.
+func (bl *BeeLogger) Informational(format string, v ...interface{}) {
+ if LevelInfo > bl.level {
+ return
+ }
+ bl.writeMsg(LevelInfo, format, v...)
+}
+
+// Debug Log DEBUG level message.
+func (bl *BeeLogger) Debug(format string, v ...interface{}) {
+ if LevelDebug > bl.level {
+ return
+ }
+ bl.writeMsg(LevelDebug, format, v...)
+}
+
+// Warn Log WARN level message.
+// compatibility alias for Warning()
+func (bl *BeeLogger) Warn(format string, v ...interface{}) {
+ if LevelWarn > bl.level {
+ return
+ }
+ bl.writeMsg(LevelWarn, format, v...)
+}
+
+// Info Log INFO level message.
+// compatibility alias for Informational()
+func (bl *BeeLogger) Info(format string, v ...interface{}) {
+ if LevelInfo > bl.level {
+ return
+ }
+ bl.writeMsg(LevelInfo, format, v...)
+}
+
+// Trace Log TRACE level message.
+// compatibility alias for Debug()
+func (bl *BeeLogger) Trace(format string, v ...interface{}) {
+ if LevelDebug > bl.level {
+ return
+ }
+ bl.writeMsg(LevelDebug, format, v...)
+}
+
+// Flush flush all chan data.
+func (bl *BeeLogger) Flush() {
+ if bl.asynchronous {
+ bl.signalChan <- "flush"
+ bl.wg.Wait()
+ bl.wg.Add(1)
+ return
+ }
+ bl.flush()
+}
+
+// Close close logger, flush all chan data and destroy all adapters in BeeLogger.
+func (bl *BeeLogger) Close() {
+ if bl.asynchronous {
+ bl.signalChan <- "close"
+ bl.wg.Wait()
+ close(bl.msgChan)
+ } else {
+ bl.flush()
+ for _, l := range bl.outputs {
+ l.Destroy()
+ }
+ bl.outputs = nil
+ }
+ close(bl.signalChan)
+}
+
+// Reset close all outputs, and set bl.outputs to nil
+func (bl *BeeLogger) Reset() {
+ bl.Flush()
+ for _, l := range bl.outputs {
+ l.Destroy()
+ }
+ bl.outputs = nil
+}
+
+func (bl *BeeLogger) flush() {
+ if bl.asynchronous {
+ for {
+ if len(bl.msgChan) > 0 {
+ bm := <-bl.msgChan
+ bl.writeToLoggers(bm.when, bm.msg, bm.level)
+ logMsgPool.Put(bm)
+ continue
+ }
+ break
+ }
+ }
+ for _, l := range bl.outputs {
+ l.Flush()
+ }
+}
+
+// beeLogger references the used application logger.
+var beeLogger = NewLogger()
+
+// GetBeeLogger returns the default BeeLogger
+func GetBeeLogger() *BeeLogger {
+ return beeLogger
+}
+
+var beeLoggerMap = struct {
+ sync.RWMutex
+ logs map[string]*log.Logger
+}{
+ logs: map[string]*log.Logger{},
+}
+
+// GetLogger returns the default BeeLogger
+func GetLogger(prefixes ...string) *log.Logger {
+ prefix := append(prefixes, "")[0]
+ if prefix != "" {
+ prefix = fmt.Sprintf(`[%s] `, strings.ToUpper(prefix))
+ }
+ beeLoggerMap.RLock()
+ l, ok := beeLoggerMap.logs[prefix]
+ if ok {
+ beeLoggerMap.RUnlock()
+ return l
+ }
+ beeLoggerMap.RUnlock()
+ beeLoggerMap.Lock()
+ defer beeLoggerMap.Unlock()
+ l, ok = beeLoggerMap.logs[prefix]
+ if !ok {
+ l = log.New(beeLogger, prefix, 0)
+ beeLoggerMap.logs[prefix] = l
+ }
+ return l
+}
+
+// Reset will remove all the adapter
+func Reset() {
+ beeLogger.Reset()
+}
+
+// Async set the beelogger with Async mode and hold msglen messages
+func Async(msgLen ...int64) *BeeLogger {
+ return beeLogger.Async(msgLen...)
+}
+
+// SetLevel sets the global log level used by the simple logger.
+func SetLevel(l int) {
+ beeLogger.SetLevel(l)
+}
+
+// SetPrefix sets the prefix
+func SetPrefix(s string) {
+ beeLogger.SetPrefix(s)
+}
+
+// EnableFuncCallDepth enable log funcCallDepth
+func EnableFuncCallDepth(b bool) {
+ beeLogger.enableFuncCallDepth = b
+}
+
+// SetLogFuncCall set the CallDepth, default is 4
+func SetLogFuncCall(b bool) {
+ beeLogger.EnableFuncCallDepth(b)
+ beeLogger.SetLogFuncCallDepth(4)
+}
+
+// SetLogFuncCallDepth set log funcCallDepth
+func SetLogFuncCallDepth(d int) {
+ beeLogger.loggerFuncCallDepth = d
+}
+
+// SetLogger sets a new logger.
+func SetLogger(adapter string, config ...string) error {
+ return beeLogger.SetLogger(adapter, config...)
+}
+
+// Emergency logs a message at emergency level.
+func Emergency(f interface{}, v ...interface{}) {
+ beeLogger.Emergency(formatLog(f, v...))
+}
+
+// Alert logs a message at alert level.
+func Alert(f interface{}, v ...interface{}) {
+ beeLogger.Alert(formatLog(f, v...))
+}
+
+// Critical logs a message at critical level.
+func Critical(f interface{}, v ...interface{}) {
+ beeLogger.Critical(formatLog(f, v...))
+}
+
+// Error logs a message at error level.
+func Error(f interface{}, v ...interface{}) {
+ beeLogger.Error(formatLog(f, v...))
+}
+
+// Warning logs a message at warning level.
+func Warning(f interface{}, v ...interface{}) {
+ beeLogger.Warn(formatLog(f, v...))
+}
+
+// Warn compatibility alias for Warning()
+func Warn(f interface{}, v ...interface{}) {
+ beeLogger.Warn(formatLog(f, v...))
+}
+
+// Notice logs a message at notice level.
+func Notice(f interface{}, v ...interface{}) {
+ beeLogger.Notice(formatLog(f, v...))
+}
+
+// Informational logs a message at info level.
+func Informational(f interface{}, v ...interface{}) {
+ beeLogger.Info(formatLog(f, v...))
+}
+
+// Info compatibility alias for Informational()
+func Info(f interface{}, v ...interface{}) {
+	beeLogger.Info(formatLog(f, v...))
+}
+
+// Debug logs a message at debug level.
+func Debug(f interface{}, v ...interface{}) {
+ beeLogger.Debug(formatLog(f, v...))
+}
+
+// Trace logs a message at trace level.
+// compatibility alias for Debug()
+func Trace(f interface{}, v ...interface{}) {
+	beeLogger.Trace(formatLog(f, v...))
+}
+
+func formatLog(f interface{}, v ...interface{}) string {
+ var msg string
+ switch f.(type) {
+ case string:
+ msg = f.(string)
+ if len(v) == 0 {
+ return msg
+ }
+ if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") {
+ //format string
+ } else {
+ //do not contain format char
+ msg += strings.Repeat(" %v", len(v))
+ }
+ default:
+ msg = fmt.Sprint(f)
+ if len(v) == 0 {
+ return msg
+ }
+ msg += strings.Repeat(" %v", len(v))
+ }
+ return fmt.Sprintf(msg, v...)
+}
diff --git a/pkg/logs/logger.go b/pkg/logs/logger.go
new file mode 100644
index 00000000..a28bff6f
--- /dev/null
+++ b/pkg/logs/logger.go
@@ -0,0 +1,176 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "io"
+ "runtime"
+ "sync"
+ "time"
+)
+
+type logWriter struct {
+ sync.Mutex
+ writer io.Writer
+}
+
+func newLogWriter(wr io.Writer) *logWriter {
+ return &logWriter{writer: wr}
+}
+
+func (lg *logWriter) writeln(when time.Time, msg string) (int, error) {
+ lg.Lock()
+ h, _, _ := formatTimeHeader(when)
+ n, err := lg.writer.Write(append(append(h, msg...), '\n'))
+ lg.Unlock()
+ return n, err
+}
+
+const (
+ y1 = `0123456789`
+ y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
+ y3 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999`
+ y4 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
+ mo1 = `000000000111`
+ mo2 = `123456789012`
+ d1 = `0000000001111111111222222222233`
+ d2 = `1234567890123456789012345678901`
+ h1 = `000000000011111111112222`
+ h2 = `012345678901234567890123`
+ mi1 = `000000000011111111112222222222333333333344444444445555555555`
+ mi2 = `012345678901234567890123456789012345678901234567890123456789`
+ s1 = `000000000011111111112222222222333333333344444444445555555555`
+ s2 = `012345678901234567890123456789012345678901234567890123456789`
+ ns1 = `0123456789`
+)
+
+func formatTimeHeader(when time.Time) ([]byte, int, int) {
+ y, mo, d := when.Date()
+ h, mi, s := when.Clock()
+ ns := when.Nanosecond() / 1000000
+ //len("2006/01/02 15:04:05.123 ")==24
+ var buf [24]byte
+
+ buf[0] = y1[y/1000%10]
+ buf[1] = y2[y/100]
+ buf[2] = y3[y-y/100*100]
+ buf[3] = y4[y-y/100*100]
+ buf[4] = '/'
+ buf[5] = mo1[mo-1]
+ buf[6] = mo2[mo-1]
+ buf[7] = '/'
+ buf[8] = d1[d-1]
+ buf[9] = d2[d-1]
+ buf[10] = ' '
+ buf[11] = h1[h]
+ buf[12] = h2[h]
+ buf[13] = ':'
+ buf[14] = mi1[mi]
+ buf[15] = mi2[mi]
+ buf[16] = ':'
+ buf[17] = s1[s]
+ buf[18] = s2[s]
+ buf[19] = '.'
+ buf[20] = ns1[ns/100]
+ buf[21] = ns1[ns%100/10]
+ buf[22] = ns1[ns%10]
+
+ buf[23] = ' '
+
+ return buf[0:], d, h
+}
+
+var (
+ green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109})
+ white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109})
+ yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109})
+ red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109})
+ blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109})
+ magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109})
+ cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109})
+
+ w32Green = string([]byte{27, 91, 52, 50, 109})
+ w32White = string([]byte{27, 91, 52, 55, 109})
+ w32Yellow = string([]byte{27, 91, 52, 51, 109})
+ w32Red = string([]byte{27, 91, 52, 49, 109})
+ w32Blue = string([]byte{27, 91, 52, 52, 109})
+ w32Magenta = string([]byte{27, 91, 52, 53, 109})
+ w32Cyan = string([]byte{27, 91, 52, 54, 109})
+
+ reset = string([]byte{27, 91, 48, 109})
+)
+
+var once sync.Once
+var colorMap map[string]string
+
+func initColor() {
+ if runtime.GOOS == "windows" {
+ green = w32Green
+ white = w32White
+ yellow = w32Yellow
+ red = w32Red
+ blue = w32Blue
+ magenta = w32Magenta
+ cyan = w32Cyan
+ }
+ colorMap = map[string]string{
+ //by color
+ "green": green,
+ "white": white,
+ "yellow": yellow,
+ "red": red,
+ //by method
+ "GET": blue,
+ "POST": cyan,
+ "PUT": yellow,
+ "DELETE": red,
+ "PATCH": green,
+ "HEAD": magenta,
+ "OPTIONS": white,
+ }
+}
+
+// ColorByStatus return color by http code
+// 2xx return Green
+// 3xx return White
+// 4xx return Yellow
+// 5xx return Red
+func ColorByStatus(code int) string {
+ once.Do(initColor)
+ switch {
+ case code >= 200 && code < 300:
+ return colorMap["green"]
+ case code >= 300 && code < 400:
+ return colorMap["white"]
+ case code >= 400 && code < 500:
+ return colorMap["yellow"]
+ default:
+ return colorMap["red"]
+ }
+}
+
+// ColorByMethod returns the console color escape for the given HTTP method
+func ColorByMethod(method string) string {
+	once.Do(initColor)
+	if c := colorMap[method]; c != "" {
+		return c
+	}
+	return reset
+}
+
+// ResetColor return reset color
+func ResetColor() string {
+ return reset
+}
diff --git a/pkg/logs/logger_test.go b/pkg/logs/logger_test.go
new file mode 100644
index 00000000..15be500d
--- /dev/null
+++ b/pkg/logs/logger_test.go
@@ -0,0 +1,57 @@
+// Copyright 2016 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "testing"
+ "time"
+)
+
+func TestFormatHeader_0(t *testing.T) {
+ tm := time.Now()
+ if tm.Year() >= 2100 {
+ t.FailNow()
+ }
+ dur := time.Second
+ for {
+ if tm.Year() >= 2100 {
+ break
+ }
+ h, _, _ := formatTimeHeader(tm)
+ if tm.Format("2006/01/02 15:04:05.000 ") != string(h) {
+ t.Log(tm)
+ t.FailNow()
+ }
+ tm = tm.Add(dur)
+ dur *= 2
+ }
+}
+
+func TestFormatHeader_1(t *testing.T) {
+ tm := time.Now()
+ year := tm.Year()
+ dur := time.Second
+ for {
+ if tm.Year() >= year+1 {
+ break
+ }
+ h, _, _ := formatTimeHeader(tm)
+ if tm.Format("2006/01/02 15:04:05.000 ") != string(h) {
+ t.Log(tm)
+ t.FailNow()
+ }
+ tm = tm.Add(dur)
+ }
+}
diff --git a/pkg/logs/multifile.go b/pkg/logs/multifile.go
new file mode 100644
index 00000000..90168274
--- /dev/null
+++ b/pkg/logs/multifile.go
@@ -0,0 +1,119 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "encoding/json"
+ "time"
+)
+
+// A filesLogWriter manages several fileLogWriter
+// filesLogWriter will write logs to the file named in the json configuration and write each separated level's logs to its corresponding file
+// means if the file name in configuration is project.log filesLogWriter will create project.error.log/project.debug.log
+// and write the error-level logs to project.error.log and write the debug-level logs to project.debug.log
+// the rotate attribute also acts like fileLogWriter
+type multiFileLogWriter struct {
+ writers [LevelDebug + 1 + 1]*fileLogWriter // the last one for fullLogWriter
+ fullLogWriter *fileLogWriter
+ Separate []string `json:"separate"`
+}
+
+var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"}
+
+// Init file logger with json config.
+// jsonConfig like:
+// {
+// "filename":"logs/beego.log",
+// "maxLines":0,
+// "maxsize":0,
+// "daily":true,
+// "maxDays":15,
+// "rotate":true,
+// "perm":0600,
+// "separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"],
+// }
+
+func (f *multiFileLogWriter) Init(config string) error {
+ writer := newFileWriter().(*fileLogWriter)
+ err := writer.Init(config)
+ if err != nil {
+ return err
+ }
+ f.fullLogWriter = writer
+ f.writers[LevelDebug+1] = writer
+
+ //unmarshal "separate" field to f.Separate
+ json.Unmarshal([]byte(config), f)
+
+ jsonMap := map[string]interface{}{}
+ json.Unmarshal([]byte(config), &jsonMap)
+
+ for i := LevelEmergency; i < LevelDebug+1; i++ {
+ for _, v := range f.Separate {
+ if v == levelNames[i] {
+ jsonMap["filename"] = f.fullLogWriter.fileNameOnly + "." + levelNames[i] + f.fullLogWriter.suffix
+ jsonMap["level"] = i
+ bs, _ := json.Marshal(jsonMap)
+ writer = newFileWriter().(*fileLogWriter)
+ err := writer.Init(string(bs))
+ if err != nil {
+ return err
+ }
+ f.writers[i] = writer
+ }
+ }
+ }
+
+ return nil
+}
+
+func (f *multiFileLogWriter) Destroy() {
+ for i := 0; i < len(f.writers); i++ {
+ if f.writers[i] != nil {
+ f.writers[i].Destroy()
+ }
+ }
+}
+
+func (f *multiFileLogWriter) WriteMsg(when time.Time, msg string, level int) error {
+ if f.fullLogWriter != nil {
+ f.fullLogWriter.WriteMsg(when, msg, level)
+ }
+ for i := 0; i < len(f.writers)-1; i++ {
+ if f.writers[i] != nil {
+ if level == f.writers[i].Level {
+ f.writers[i].WriteMsg(when, msg, level)
+ }
+ }
+ }
+ return nil
+}
+
+func (f *multiFileLogWriter) Flush() {
+ for i := 0; i < len(f.writers); i++ {
+ if f.writers[i] != nil {
+ f.writers[i].Flush()
+ }
+ }
+}
+
+// newFilesWriter creates a multiFileLogWriter, returned as a Logger.
+func newFilesWriter() Logger {
+ return &multiFileLogWriter{}
+}
+
+func init() {
+ Register(AdapterMultiFile, newFilesWriter)
+}
diff --git a/pkg/logs/multifile_test.go b/pkg/logs/multifile_test.go
new file mode 100644
index 00000000..57b96094
--- /dev/null
+++ b/pkg/logs/multifile_test.go
@@ -0,0 +1,78 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "bufio"
+ "os"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+func TestFiles_1(t *testing.T) {
+ log := NewLogger(10000)
+ log.SetLogger("multifile", `{"filename":"test.log","separate":["emergency", "alert", "critical", "error", "warning", "notice", "info", "debug"]}`)
+ log.Debug("debug")
+ log.Informational("info")
+ log.Notice("notice")
+ log.Warning("warning")
+ log.Error("error")
+ log.Alert("alert")
+ log.Critical("critical")
+ log.Emergency("emergency")
+ fns := []string{""}
+ fns = append(fns, levelNames[0:]...)
+ name := "test"
+ suffix := ".log"
+ for _, fn := range fns {
+
+ file := name + suffix
+ if fn != "" {
+ file = name + "." + fn + suffix
+ }
+ f, err := os.Open(file)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b := bufio.NewReader(f)
+ lineNum := 0
+ lastLine := ""
+ for {
+ line, _, err := b.ReadLine()
+ if err != nil {
+ break
+ }
+ if len(line) > 0 {
+ lastLine = string(line)
+ lineNum++
+ }
+ }
+ var expected = 1
+ if fn == "" {
+ expected = LevelDebug + 1
+ }
+ if lineNum != expected {
+ t.Fatal(file, "has", lineNum, "lines not "+strconv.Itoa(expected)+" lines")
+ }
+ if lineNum == 1 {
+ if !strings.Contains(lastLine, fn) {
+ t.Fatal(file + " " + lastLine + " not contains the log msg " + fn)
+ }
+ }
+ os.Remove(file)
+ }
+
+}
diff --git a/pkg/logs/slack.go b/pkg/logs/slack.go
new file mode 100644
index 00000000..1cd2e5ae
--- /dev/null
+++ b/pkg/logs/slack.go
@@ -0,0 +1,60 @@
+package logs
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "time"
+)
+
+// SLACKWriter implements beego LoggerInterface and is used to send messages to a Slack incoming webhook
+type SLACKWriter struct {
+ WebhookURL string `json:"webhookurl"`
+ Level int `json:"level"`
+}
+
+// newSLACKWriter creates a Slack writer.
+func newSLACKWriter() Logger {
+ return &SLACKWriter{Level: LevelTrace}
+}
+
+// Init SLACKWriter with json config string
+func (s *SLACKWriter) Init(jsonconfig string) error {
+ return json.Unmarshal([]byte(jsonconfig), s)
+}
+
+// WriteMsg writes a message to the Slack webhook.
+// It posts the timestamped message as the webhook payload.
+func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error {
+ if level > s.Level {
+ return nil
+ }
+
+ text := fmt.Sprintf("{\"text\": \"%s %s\"}", when.Format("2006-01-02 15:04:05"), msg)
+
+ form := url.Values{}
+ form.Add("payload", text)
+
+ resp, err := http.PostForm(s.WebhookURL, form)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode)
+ }
+ return nil
+}
+
+// Flush implementing method. empty.
+func (s *SLACKWriter) Flush() {
+}
+
+// Destroy implementing method. empty.
+func (s *SLACKWriter) Destroy() {
+}
+
+func init() {
+ Register(AdapterSlack, newSLACKWriter)
+}
diff --git a/pkg/logs/smtp.go b/pkg/logs/smtp.go
new file mode 100644
index 00000000..6208d7b8
--- /dev/null
+++ b/pkg/logs/smtp.go
@@ -0,0 +1,149 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "crypto/tls"
+ "encoding/json"
+ "fmt"
+ "net"
+ "net/smtp"
+ "strings"
+ "time"
+)
+
+// SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server.
+type SMTPWriter struct {
+ Username string `json:"username"`
+ Password string `json:"password"`
+ Host string `json:"host"`
+ Subject string `json:"subject"`
+ FromAddress string `json:"fromAddress"`
+ RecipientAddresses []string `json:"sendTos"`
+ Level int `json:"level"`
+}
+
+// newSMTPWriter creates an SMTP writer.
+func newSMTPWriter() Logger {
+ return &SMTPWriter{Level: LevelTrace}
+}
+
+// Init smtp writer with json config.
+// config like:
+// {
+// "username":"example@gmail.com",
+// "password":"password",
+// "host":"smtp.gmail.com:465",
+// "subject":"email title",
+// "fromAddress":"from@example.com",
+// "sendTos":["email1","email2"],
+// "level":LevelError
+// }
+func (s *SMTPWriter) Init(jsonconfig string) error {
+ return json.Unmarshal([]byte(jsonconfig), s)
+}
+
+func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth {
+ if len(strings.Trim(s.Username, " ")) == 0 && len(strings.Trim(s.Password, " ")) == 0 {
+ return nil
+ }
+ return smtp.PlainAuth(
+ "",
+ s.Username,
+ s.Password,
+ host,
+ )
+}
+
+func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error {
+ client, err := smtp.Dial(hostAddressWithPort)
+ if err != nil {
+ return err
+ }
+
+ host, _, _ := net.SplitHostPort(hostAddressWithPort)
+ tlsConn := &tls.Config{
+ InsecureSkipVerify: true,
+ ServerName: host,
+ }
+ if err = client.StartTLS(tlsConn); err != nil {
+ return err
+ }
+
+ if auth != nil {
+ if err = client.Auth(auth); err != nil {
+ return err
+ }
+ }
+
+ if err = client.Mail(fromAddress); err != nil {
+ return err
+ }
+
+ for _, rec := range recipients {
+ if err = client.Rcpt(rec); err != nil {
+ return err
+ }
+ }
+
+ w, err := client.Data()
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(msgContent)
+ if err != nil {
+ return err
+ }
+
+ err = w.Close()
+ if err != nil {
+ return err
+ }
+
+ return client.Quit()
+}
+
+// WriteMsg writes a message via the SMTP writer.
+// It sends an email with the configured subject containing only this message.
+func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error {
+ if level > s.Level {
+ return nil
+ }
+
+ hp := strings.Split(s.Host, ":")
+
+ // Set up authentication information.
+ auth := s.getSMTPAuth(hp[0])
+
+ // Connect to the server, authenticate, set the sender and recipient,
+ // and send the email all in one step.
+ contentType := "Content-Type: text/plain" + "; charset=UTF-8"
+ mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress +
+ ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", when.Format("2006-01-02 15:04:05")) + msg)
+
+ return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg)
+}
+
+// Flush implementing method. empty.
+func (s *SMTPWriter) Flush() {
+}
+
+// Destroy implementing method. empty.
+func (s *SMTPWriter) Destroy() {
+}
+
+func init() {
+ Register(AdapterMail, newSMTPWriter)
+}
diff --git a/pkg/logs/smtp_test.go b/pkg/logs/smtp_test.go
new file mode 100644
index 00000000..28e762d2
--- /dev/null
+++ b/pkg/logs/smtp_test.go
@@ -0,0 +1,27 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package logs
+
+import (
+ "testing"
+ "time"
+)
+
+func TestSmtp(t *testing.T) {
+ log := NewLogger(10000)
+ log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx","host":"smtp.gmail.com:587","sendTos":["xiemengjun@gmail.com"]}`)
+ log.Critical("sendmail critical")
+ time.Sleep(time.Second * 30)
+}
diff --git a/pkg/metric/prometheus.go b/pkg/metric/prometheus.go
new file mode 100644
index 00000000..7722240b
--- /dev/null
+++ b/pkg/metric/prometheus.go
@@ -0,0 +1,99 @@
+// Copyright 2020 astaxie
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package metric
+
+import (
+ "net/http"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/logs"
+)
+
+func PrometheusMiddleWare(next http.Handler) http.Handler {
+ summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{
+ Name: "beego",
+ Subsystem: "http_request",
+ ConstLabels: map[string]string{
+ "server": beego.BConfig.ServerName,
+ "env": beego.BConfig.RunMode,
+ "appname": beego.BConfig.AppName,
+ },
+ Help: "The statics info for http request",
+ }, []string{"pattern", "method", "status", "duration"})
+
+ prometheus.MustRegister(summaryVec)
+
+ registerBuildInfo()
+
+ return http.HandlerFunc(func(writer http.ResponseWriter, q *http.Request) {
+ start := time.Now()
+ next.ServeHTTP(writer, q)
+ end := time.Now()
+ go report(end.Sub(start), writer, q, summaryVec)
+ })
+}
+
+func registerBuildInfo() {
+ buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{
+ Name: "beego",
+ Subsystem: "build_info",
+ Help: "The building information",
+ ConstLabels: map[string]string{
+ "appname": beego.BConfig.AppName,
+ "build_version": beego.BuildVersion,
+ "build_revision": beego.BuildGitRevision,
+ "build_status": beego.BuildStatus,
+ "build_tag": beego.BuildTag,
+ "build_time": strings.Replace(beego.BuildTime, "--", " ", 1),
+ "go_version": beego.GoVersion,
+ "git_branch": beego.GitBranch,
+ "start_time": time.Now().Format("2006-01-02 15:04:05"),
+ },
+ }, []string{})
+
+ prometheus.MustRegister(buildInfo)
+ buildInfo.WithLabelValues().Set(1)
+}
+
+func report(dur time.Duration, writer http.ResponseWriter, q *http.Request, vec *prometheus.SummaryVec) {
+ ctrl := beego.BeeApp.Handlers
+ ctx := ctrl.GetContext()
+ ctx.Reset(writer, q)
+ defer ctrl.GiveBackContext(ctx)
+
+ // We cannot read the status code from q.Response.StatusCode
+ // since the http server does not set q.Response. So q.Response is nil
+ // Thus, we use reflection to read the status from writer whose concrete type is http.response
+ responseVal := reflect.ValueOf(writer).Elem()
+ field := responseVal.FieldByName("status")
+ status := -1
+ if field.IsValid() && field.Kind() == reflect.Int {
+ status = int(field.Int())
+ }
+ ptn := "UNKNOWN"
+ if rt, found := ctrl.FindRouter(ctx); found {
+ ptn = rt.GetPattern()
+ } else {
+ logs.Warn("we can not find the router info for this request, so request will be recorded as UNKNOWN: " + q.URL.String())
+ }
+ ms := dur / time.Millisecond
+ vec.WithLabelValues(ptn, q.Method, strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms))
+}
diff --git a/pkg/metric/prometheus_test.go b/pkg/metric/prometheus_test.go
new file mode 100644
index 00000000..d82a6dec
--- /dev/null
+++ b/pkg/metric/prometheus_test.go
@@ -0,0 +1,42 @@
+// Copyright 2020 astaxie
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package metric
+
+import (
+ "net/http"
+ "net/url"
+ "testing"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+
+ "github.com/astaxie/beego/context"
+)
+
+func TestPrometheusMiddleWare(t *testing.T) {
+ middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
+ writer := &context.Response{}
+ request := &http.Request{
+ URL: &url.URL{
+ Host: "localhost",
+ RawPath: "/a/b/c",
+ },
+ Method: "POST",
+ }
+ vec := prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{"pattern", "method", "status", "duration"})
+
+ report(time.Second, writer, request, vec)
+ middleware.ServeHTTP(writer, request)
+}
diff --git a/pkg/migration/ddl.go b/pkg/migration/ddl.go
new file mode 100644
index 00000000..cd2c1c49
--- /dev/null
+++ b/pkg/migration/ddl.go
@@ -0,0 +1,395 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package migration
+
+import (
+ "fmt"
+
+ "github.com/astaxie/beego/logs"
+)
+
+// Index struct defines the structure of Index Columns
+type Index struct {
+ Name string
+}
+
+// Unique struct defines a single unique key combination
+type Unique struct {
+ Definition string
+ Columns []*Column
+}
+
+//Column struct defines a single column of a table
+type Column struct {
+ Name string
+ Inc string
+ Null string
+ Default string
+ Unsign string
+ DataType string
+ remove bool
+ Modify bool
+}
+
+// Foreign struct defines a single foreign relationship
+type Foreign struct {
+ ForeignTable string
+ ForeignColumn string
+ OnDelete string
+ OnUpdate string
+ Column
+}
+
+// RenameColumn struct allows renaming of columns
+type RenameColumn struct {
+ OldName string
+ OldNull string
+ OldDefault string
+ OldUnsign string
+ OldDataType string
+ NewName string
+ Column
+}
+
+// CreateTable creates the table on system
+func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) {
+ m.TableName = tablename
+ m.Engine = engine
+ m.Charset = charset
+ m.ModifyType = "create"
+}
+
+// AlterTable set the ModifyType to alter
+func (m *Migration) AlterTable(tablename string) {
+ m.TableName = tablename
+ m.ModifyType = "alter"
+}
+
+// NewCol creates a new standard column and attaches it to m struct
+func (m *Migration) NewCol(name string) *Column {
+ col := &Column{Name: name}
+ m.AddColumns(col)
+ return col
+}
+
+//PriCol creates a new primary column and attaches it to m struct
+func (m *Migration) PriCol(name string) *Column {
+ col := &Column{Name: name}
+ m.AddColumns(col)
+ m.AddPrimary(col)
+ return col
+}
+
+//UniCol creates / appends columns to specified unique key and attaches it to m struct
+func (m *Migration) UniCol(uni, name string) *Column {
+ col := &Column{Name: name}
+ m.AddColumns(col)
+
+ uniqueOriginal := &Unique{}
+
+ for _, unique := range m.Uniques {
+ if unique.Definition == uni {
+ unique.AddColumnsToUnique(col)
+ uniqueOriginal = unique
+ }
+ }
+ if uniqueOriginal.Definition == "" {
+ unique := &Unique{Definition: uni}
+ unique.AddColumnsToUnique(col)
+ m.AddUnique(unique)
+ }
+
+ return col
+}
+
+//ForeignCol creates a new foreign column and returns the instance of column
+func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) {
+
+ foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable}
+ foreign.Name = colname
+ m.AddForeign(foreign)
+ return foreign
+}
+
+//SetOnDelete sets the on delete of foreign
+func (foreign *Foreign) SetOnDelete(del string) *Foreign {
+ foreign.OnDelete = "ON DELETE" + del
+ return foreign
+}
+
+//SetOnUpdate sets the on update of foreign
+func (foreign *Foreign) SetOnUpdate(update string) *Foreign {
+ foreign.OnUpdate = "ON UPDATE" + update
+ return foreign
+}
+
+//Remove marks the columns to be removed.
+//it allows reverse m to create the column.
+func (c *Column) Remove() {
+ c.remove = true
+}
+
+//SetAuto enables auto_increment of column (can be used once)
+func (c *Column) SetAuto(inc bool) *Column {
+ if inc {
+ c.Inc = "auto_increment"
+ }
+ return c
+}
+
+//SetNullable sets the column to be null
+func (c *Column) SetNullable(null bool) *Column {
+ if null {
+ c.Null = ""
+
+ } else {
+ c.Null = "NOT NULL"
+ }
+ return c
+}
+
+//SetDefault sets the default value, prepend with "DEFAULT "
+func (c *Column) SetDefault(def string) *Column {
+ c.Default = "DEFAULT " + def
+ return c
+}
+
+//SetUnsigned sets the column to be unsigned int
+func (c *Column) SetUnsigned(unsign bool) *Column {
+ if unsign {
+ c.Unsign = "UNSIGNED"
+ }
+ return c
+}
+
+//SetDataType sets the dataType of the column
+func (c *Column) SetDataType(dataType string) *Column {
+ c.DataType = dataType
+ return c
+}
+
+//SetOldNullable allows reverting to previous nullable on reverse ms
+func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn {
+ if null {
+ c.OldNull = ""
+
+ } else {
+ c.OldNull = "NOT NULL"
+ }
+ return c
+}
+
+//SetOldDefault allows reverting to previous default on reverse ms
+func (c *RenameColumn) SetOldDefault(def string) *RenameColumn {
+ c.OldDefault = def
+ return c
+}
+
+//SetOldUnsigned allows reverting to previous unsgined on reverse ms
+func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn {
+ if unsign {
+ c.OldUnsign = "UNSIGNED"
+ }
+ return c
+}
+
+//SetOldDataType allows reverting to previous datatype on reverse ms
+func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn {
+ c.OldDataType = dataType
+ return c
+}
+
+//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m)
+func (c *Column) SetPrimary(m *Migration) *Column {
+ m.Primary = append(m.Primary, c)
+ return c
+}
+
+//AddColumnsToUnique adds the columns to Unique Struct
+func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique {
+
+ unique.Columns = append(unique.Columns, columns...)
+
+ return unique
+}
+
+//AddColumns adds columns to m struct
+func (m *Migration) AddColumns(columns ...*Column) *Migration {
+
+ m.Columns = append(m.Columns, columns...)
+
+ return m
+}
+
+//AddPrimary adds the column to primary in m struct
+func (m *Migration) AddPrimary(primary *Column) *Migration {
+ m.Primary = append(m.Primary, primary)
+ return m
+}
+
+//AddUnique adds the column to unique in m struct
+func (m *Migration) AddUnique(unique *Unique) *Migration {
+ m.Uniques = append(m.Uniques, unique)
+ return m
+}
+
+//AddForeign adds the column to foreign in m struct
+func (m *Migration) AddForeign(foreign *Foreign) *Migration {
+ m.Foreigns = append(m.Foreigns, foreign)
+ return m
+}
+
+//AddIndex adds the column to index in m struct
+func (m *Migration) AddIndex(index *Index) *Migration {
+ m.Indexes = append(m.Indexes, index)
+ return m
+}
+
+//RenameColumn allows renaming of columns
+func (m *Migration) RenameColumn(from, to string) *RenameColumn {
+ rename := &RenameColumn{OldName: from, NewName: to}
+ m.Renames = append(m.Renames, rename)
+ return rename
+}
+
+//GetSQL returns the generated sql depending on ModifyType
+func (m *Migration) GetSQL() (sql string) {
+ sql = ""
+ switch m.ModifyType {
+ case "create":
+ {
+ sql += fmt.Sprintf("CREATE TABLE `%s` (", m.TableName)
+ for index, column := range m.Columns {
+ sql += fmt.Sprintf("\n `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
+ if len(m.Columns) > index+1 {
+ sql += ","
+ }
+ }
+
+ if len(m.Primary) > 0 {
+ sql += fmt.Sprintf(",\n PRIMARY KEY( ")
+ }
+ for index, column := range m.Primary {
+ sql += fmt.Sprintf(" `%s`", column.Name)
+ if len(m.Primary) > index+1 {
+ sql += ","
+ }
+
+ }
+ if len(m.Primary) > 0 {
+ sql += fmt.Sprintf(")")
+ }
+
+ for _, unique := range m.Uniques {
+ sql += fmt.Sprintf(",\n UNIQUE KEY `%s`( ", unique.Definition)
+ for index, column := range unique.Columns {
+ sql += fmt.Sprintf(" `%s`", column.Name)
+ if len(unique.Columns) > index+1 {
+ sql += ","
+ }
+ }
+ sql += fmt.Sprintf(")")
+ }
+ for _, foreign := range m.Foreigns {
+ sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
+ sql += fmt.Sprintf(",\n KEY `%s_%s_foreign`(`%s`),", m.TableName, foreign.Column.Name, foreign.Column.Name)
+ sql += fmt.Sprintf("\n CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate)
+
+ }
+ sql += fmt.Sprintf(")ENGINE=%s DEFAULT CHARSET=%s;", m.Engine, m.Charset)
+ break
+ }
+ case "alter":
+ {
+ sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName)
+ for index, column := range m.Columns {
+ if !column.remove {
+ logs.Info("col")
+ sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
+ } else {
+ sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
+ }
+
+ if len(m.Columns) > index+1 {
+ sql += ","
+ }
+ }
+ for index, column := range m.Renames {
+ sql += fmt.Sprintf("CHANGE COLUMN `%s` `%s` %s %s %s %s %s", column.OldName, column.NewName, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
+ if len(m.Renames) > index+1 {
+ sql += ","
+ }
+ }
+
+ for index, foreign := range m.Foreigns {
+ sql += fmt.Sprintf("ADD `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
+ sql += fmt.Sprintf(",\n ADD KEY `%s_%s_foreign`(`%s`)", m.TableName, foreign.Column.Name, foreign.Column.Name)
+ sql += fmt.Sprintf(",\n ADD CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate)
+ if len(m.Foreigns) > index+1 {
+ sql += ","
+ }
+ }
+ sql += ";"
+
+ break
+ }
+ case "reverse":
+ {
+
+ sql += fmt.Sprintf("ALTER TABLE `%s`", m.TableName)
+ for index, column := range m.Columns {
+ if column.remove {
+ sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
+ } else {
+ sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
+ }
+ if len(m.Columns) > index+1 {
+ sql += ","
+ }
+ }
+
+ if len(m.Primary) > 0 {
+ sql += fmt.Sprintf("\n DROP PRIMARY KEY,")
+ }
+
+ for index, unique := range m.Uniques {
+ sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition)
+ if len(m.Uniques) > index+1 {
+ sql += ","
+ }
+
+ }
+ for index, column := range m.Renames {
+ sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault)
+ if len(m.Renames) > index+1 {
+ sql += ","
+ }
+ }
+
+ for _, foreign := range m.Foreigns {
+ sql += fmt.Sprintf("\n DROP KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name)
+ sql += fmt.Sprintf(",\n DROP FOREIGN KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name)
+ sql += fmt.Sprintf(",\n DROP COLUMN `%s`", foreign.Name)
+ }
+ sql += ";"
+ }
+ case "delete":
+ {
+ sql += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;", m.TableName)
+ }
+ }
+
+ return
+}
diff --git a/pkg/migration/doc.go b/pkg/migration/doc.go
new file mode 100644
index 00000000..0c6564d4
--- /dev/null
+++ b/pkg/migration/doc.go
@@ -0,0 +1,32 @@
+// Package migration enables you to generate migrations back and forth. It generates both forward and reverse migrations.
+//
+// //Creates a table
+// m.CreateTable("tablename","InnoDB","utf8");
+//
+// //Alter a table
+// m.AlterTable("tablename")
+//
+// Standard Column Methods
+// * SetDataType
+// * SetNullable
+// * SetDefault
+// * SetUnsigned (use only on integer types unless produces error)
+//
+// //Sets a primary column, multiple calls allowed, standard column methods available
+// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true)
+//
+// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index
+// m.UniCol("index","column")
+//
+// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove
+// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false)
+// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false)
+//
+// //Rename Columns, only use with Alter table, doesn't work with Create, prefix standard column methods with "Old" to
+// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)")
+// m.RenameColumn("from","to")...
+//
+// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately.
+// //Supports standard column methods, automatic reverse.
+// m.ForeignCol("local_col","foreign_col","foreign_table")
+package migration
diff --git a/pkg/migration/migration.go b/pkg/migration/migration.go
new file mode 100644
index 00000000..5ddfd972
--- /dev/null
+++ b/pkg/migration/migration.go
@@ -0,0 +1,330 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package migration is used for migration
+//
+// The table structure is as follow:
+//
+// CREATE TABLE `migrations` (
+// `id_migration` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'surrogate key',
+// `name` varchar(255) DEFAULT NULL COMMENT 'migration name, unique',
+// `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'date migrated or rolled back',
+// `statements` longtext COMMENT 'SQL statements for this migration',
+// `rollback_statements` longtext,
+// `status` enum('update','rollback') DEFAULT NULL COMMENT 'update indicates it is a normal migration while rollback means this migration is rolled back',
+// PRIMARY KEY (`id_migration`)
+// ) ENGINE=InnoDB DEFAULT CHARSET=utf8;
+package migration
+
+import (
+ "errors"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/astaxie/beego/logs"
+ "github.com/astaxie/beego/orm"
+)
+
+// Date/time layouts used by bee-generated migration names and the migrations table.
+const (
+ DateFormat = "20060102_150405"
+ DBDateFormat = "2006-01-02 15:04:05"
+)
+
+// Migrationer is an interface for all Migration struct
+type Migrationer interface {
+ Up()
+ Down()
+ Reset()
+ Exec(name, status string) error
+ GetCreated() int64
+}
+
+//Migration defines the migrations by either SQL or DDL
+type Migration struct {
+ sqls []string
+ Created string
+ TableName string
+ Engine string
+ Charset string
+ ModifyType string
+ Columns []*Column
+ Indexes []*Index
+ Primary []*Column
+ Uniques []*Unique
+ Foreigns []*Foreign
+ Renames []*RenameColumn
+ RemoveColumns []*Column
+ RemoveIndexes []*Index
+ RemoveUniques []*Unique
+ RemoveForeigns []*Foreign
+}
+
+var (
+ migrationMap map[string]Migrationer
+)
+
+func init() {
+ migrationMap = make(map[string]Migrationer)
+}
+
+// Up implement in the Inheritance struct for upgrade
+func (m *Migration) Up() {
+
+ switch m.ModifyType {
+ case "reverse":
+ m.ModifyType = "alter"
+ case "delete":
+ m.ModifyType = "create"
+ }
+ m.sqls = append(m.sqls, m.GetSQL())
+}
+
+// Down implement in the Inheritance struct for down
+func (m *Migration) Down() {
+
+ switch m.ModifyType {
+ case "alter":
+ m.ModifyType = "reverse"
+ case "create":
+ m.ModifyType = "delete"
+ }
+ m.sqls = append(m.sqls, m.GetSQL())
+}
+
+//Migrate adds the SQL to the execution list
+func (m *Migration) Migrate(migrationType string) {
+ m.ModifyType = migrationType
+ m.sqls = append(m.sqls, m.GetSQL())
+}
+
+// SQL add sql want to execute
+func (m *Migration) SQL(sql string) {
+ m.sqls = append(m.sqls, sql)
+}
+
+// Reset the sqls
+func (m *Migration) Reset() {
+ m.sqls = make([]string, 0)
+}
+
+// Exec execute the sql already add in the sql
+func (m *Migration) Exec(name, status string) error {
+ o := orm.NewOrm()
+ for _, s := range m.sqls {
+ logs.Info("exec sql:", s)
+ r := o.Raw(s)
+ _, err := r.Exec()
+ if err != nil {
+ return err
+ }
+ }
+ return m.addOrUpdateRecord(name, status)
+}
+
+func (m *Migration) addOrUpdateRecord(name, status string) error {
+	o := orm.NewOrm()
+	if status == "down" {
+		status = "rollback"
+		p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare()
+		if err != nil {
+			return err // propagate Prepare failure (was: return nil, which swallowed the error)
+		}
+		_, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name)
+		return err
+	}
+	status = "update"
+	p, err := o.Raw("insert into migrations(name, created_at, statements, status) values(?,?,?,?)").Prepare()
+	if err != nil {
+		return err
+	}
+	_, err = p.Exec(name, time.Now().Format(DBDateFormat), strings.Join(m.sqls, "; "), status)
+	return err
+}
+
+// GetCreated get the unixtime from the Created
+func (m *Migration) GetCreated() int64 {
+ t, err := time.Parse(DateFormat, m.Created)
+ if err != nil {
+ return 0
+ }
+ return t.Unix()
+}
+
+// Register register the Migration in the map
+func Register(name string, m Migrationer) error {
+ if _, ok := migrationMap[name]; ok {
+ return errors.New("already exist name:" + name)
+ }
+ migrationMap[name] = m
+ return nil
+}
+
+// Upgrade upgrade the migration from lasttime
+func Upgrade(lasttime int64) error {
+ sm := sortMap(migrationMap)
+ i := 0
+ migs, _ := getAllMigrations()
+ for _, v := range sm {
+ if _, ok := migs[v.name]; !ok {
+ logs.Info("start upgrade", v.name)
+ v.m.Reset()
+ v.m.Up()
+ err := v.m.Exec(v.name, "up")
+ if err != nil {
+ logs.Error("execute error:", err)
+ time.Sleep(2 * time.Second)
+ return err
+ }
+ logs.Info("end upgrade:", v.name)
+ i++
+ }
+ }
+ logs.Info("total success upgrade:", i, " migration")
+ time.Sleep(2 * time.Second)
+ return nil
+}
+
+// Rollback rollback the migration by the name
+func Rollback(name string) error {
+ if v, ok := migrationMap[name]; ok {
+ logs.Info("start rollback")
+ v.Reset()
+ v.Down()
+ err := v.Exec(name, "down")
+ if err != nil {
+ logs.Error("execute error:", err)
+ time.Sleep(2 * time.Second)
+ return err
+ }
+ logs.Info("end rollback")
+ time.Sleep(2 * time.Second)
+ return nil
+ }
+ logs.Error("not exist the migrationMap name:" + name)
+ time.Sleep(2 * time.Second)
+ return errors.New("not exist the migrationMap name:" + name)
+}
+
+// Reset reset all migration
+// run all migration's down function
+func Reset() error {
+ sm := sortMap(migrationMap)
+ i := 0
+ for j := len(sm) - 1; j >= 0; j-- {
+ v := sm[j]
+ if isRollBack(v.name) {
+ logs.Info("skip the", v.name)
+ time.Sleep(1 * time.Second)
+ continue
+ }
+ logs.Info("start reset:", v.name)
+ v.m.Reset()
+ v.m.Down()
+ err := v.m.Exec(v.name, "down")
+ if err != nil {
+ logs.Error("execute error:", err)
+ time.Sleep(2 * time.Second)
+ return err
+ }
+ i++
+ logs.Info("end reset:", v.name)
+ }
+ logs.Info("total success reset:", i, " migration")
+ time.Sleep(2 * time.Second)
+ return nil
+}
+
+// Refresh first Reset, then Upgrade
+func Refresh() error {
+ err := Reset()
+ if err != nil {
+ logs.Error("execute error:", err)
+ time.Sleep(2 * time.Second)
+ return err
+ }
+ err = Upgrade(0)
+ return err
+}
+
+type dataSlice []data
+
+type data struct {
+ created int64
+ name string
+ m Migrationer
+}
+
+// Len is part of sort.Interface.
+func (d dataSlice) Len() int {
+ return len(d)
+}
+
+// Swap is part of sort.Interface.
+func (d dataSlice) Swap(i, j int) {
+ d[i], d[j] = d[j], d[i]
+}
+
+// Less is part of sort.Interface. We use count as the value to sort by
+func (d dataSlice) Less(i, j int) bool {
+ return d[i].created < d[j].created
+}
+
+func sortMap(m map[string]Migrationer) dataSlice {
+ s := make(dataSlice, 0, len(m))
+ for k, v := range m {
+ d := data{}
+ d.created = v.GetCreated()
+ d.name = k
+ d.m = v
+ s = append(s, d)
+ }
+ sort.Sort(s)
+ return s
+}
+
+func isRollBack(name string) bool {
+ o := orm.NewOrm()
+ var maps []orm.Params
+ num, err := o.Raw("select * from migrations where `name` = ? order by id_migration desc", name).Values(&maps)
+ if err != nil {
+ logs.Info("get name has error", err)
+ return false
+ }
+ if num <= 0 {
+ return false
+ }
+ if maps[0]["status"] == "rollback" {
+ return true
+ }
+ return false
+}
+func getAllMigrations() (map[string]string, error) {
+ o := orm.NewOrm()
+ var maps []orm.Params
+ migs := make(map[string]string)
+ num, err := o.Raw("select * from migrations order by id_migration desc").Values(&maps)
+ if err != nil {
+ logs.Info("get name has error", err)
+ return migs, err
+ }
+ if num > 0 {
+ for _, v := range maps {
+ name := v["name"].(string)
+ migs[name] = v["status"].(string)
+ }
+ }
+ return migs, nil
+}
diff --git a/pkg/mime.go b/pkg/mime.go
new file mode 100644
index 00000000..ca2878ab
--- /dev/null
+++ b/pkg/mime.go
@@ -0,0 +1,556 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+var mimemaps = map[string]string{
+ ".3dm": "x-world/x-3dmf",
+ ".3dmf": "x-world/x-3dmf",
+ ".7z": "application/x-7z-compressed",
+ ".a": "application/octet-stream",
+ ".aab": "application/x-authorware-bin",
+ ".aam": "application/x-authorware-map",
+ ".aas": "application/x-authorware-seg",
+ ".abc": "text/vndabc",
+ ".ace": "application/x-ace-compressed",
+ ".acgi": "text/html",
+ ".afl": "video/animaflex",
+ ".ai": "application/postscript",
+ ".aif": "audio/aiff",
+ ".aifc": "audio/aiff",
+ ".aiff": "audio/aiff",
+ ".aim": "application/x-aim",
+ ".aip": "text/x-audiosoft-intra",
+ ".alz": "application/x-alz-compressed",
+ ".ani": "application/x-navi-animation",
+ ".aos": "application/x-nokia-9000-communicator-add-on-software",
+ ".aps": "application/mime",
+ ".apk": "application/vnd.android.package-archive",
+ ".arc": "application/x-arc-compressed",
+ ".arj": "application/arj",
+ ".art": "image/x-jg",
+ ".asf": "video/x-ms-asf",
+ ".asm": "text/x-asm",
+ ".asp": "text/asp",
+ ".asx": "application/x-mplayer2",
+ ".au": "audio/basic",
+ ".avi": "video/x-msvideo",
+ ".avs": "video/avs-video",
+ ".bcpio": "application/x-bcpio",
+ ".bin": "application/mac-binary",
+ ".bmp": "image/bmp",
+ ".boo": "application/book",
+ ".book": "application/book",
+ ".boz": "application/x-bzip2",
+ ".bsh": "application/x-bsh",
+ ".bz2": "application/x-bzip2",
+ ".bz": "application/x-bzip",
+ ".c++": "text/plain",
+ ".c": "text/x-c",
+ ".cab": "application/vnd.ms-cab-compressed",
+ ".cat": "application/vndms-pkiseccat",
+ ".cc": "text/x-c",
+ ".ccad": "application/clariscad",
+ ".cco": "application/x-cocoa",
+ ".cdf": "application/cdf",
+ ".cer": "application/pkix-cert",
+ ".cha": "application/x-chat",
+ ".chat": "application/x-chat",
+ ".chrt": "application/vnd.kde.kchart",
+ ".class": "application/java",
+ ".com": "text/plain",
+ ".conf": "text/plain",
+ ".cpio": "application/x-cpio",
+ ".cpp": "text/x-c",
+ ".cpt": "application/mac-compactpro",
+ ".crl": "application/pkcs-crl",
+ ".crt": "application/pkix-cert",
+ ".crx": "application/x-chrome-extension",
+ ".csh": "text/x-scriptcsh",
+ ".css": "text/css",
+ ".csv": "text/csv",
+ ".cxx": "text/plain",
+ ".dar": "application/x-dar",
+ ".dcr": "application/x-director",
+ ".deb": "application/x-debian-package",
+ ".deepv": "application/x-deepv",
+ ".def": "text/plain",
+ ".der": "application/x-x509-ca-cert",
+ ".dif": "video/x-dv",
+ ".dir": "application/x-director",
+ ".divx": "video/divx",
+ ".dl": "video/dl",
+ ".dmg": "application/x-apple-diskimage",
+ ".doc": "application/msword",
+ ".dot": "application/msword",
+ ".dp": "application/commonground",
+ ".drw": "application/drafting",
+ ".dump": "application/octet-stream",
+ ".dv": "video/x-dv",
+ ".dvi": "application/x-dvi",
+ ".dwf": "drawing/x-dwf=(old)",
+ ".dwg": "application/acad",
+ ".dxf": "application/dxf",
+ ".dxr": "application/x-director",
+ ".el": "text/x-scriptelisp",
+ ".elc": "application/x-bytecodeelisp=(compiled=elisp)",
+ ".eml": "message/rfc822",
+ ".env": "application/x-envoy",
+ ".eps": "application/postscript",
+ ".es": "application/x-esrehber",
+ ".etx": "text/x-setext",
+ ".evy": "application/envoy",
+ ".exe": "application/octet-stream",
+ ".f77": "text/x-fortran",
+ ".f90": "text/x-fortran",
+ ".f": "text/x-fortran",
+ ".fdf": "application/vndfdf",
+ ".fif": "application/fractals",
+ ".fli": "video/fli",
+ ".flo": "image/florian",
+ ".flv": "video/x-flv",
+ ".flx": "text/vndfmiflexstor",
+ ".fmf": "video/x-atomic3d-feature",
+ ".for": "text/x-fortran",
+ ".fpx": "image/vndfpx",
+ ".frl": "application/freeloader",
+ ".funk": "audio/make",
+ ".g3": "image/g3fax",
+ ".g": "text/plain",
+ ".gif": "image/gif",
+ ".gl": "video/gl",
+ ".gsd": "audio/x-gsm",
+ ".gsm": "audio/x-gsm",
+ ".gsp": "application/x-gsp",
+ ".gss": "application/x-gss",
+ ".gtar": "application/x-gtar",
+ ".gz": "application/x-compressed",
+ ".gzip": "application/x-gzip",
+ ".h": "text/x-h",
+ ".hdf": "application/x-hdf",
+ ".help": "application/x-helpfile",
+ ".hgl": "application/vndhp-hpgl",
+ ".hh": "text/x-h",
+ ".hlb": "text/x-script",
+ ".hlp": "application/hlp",
+ ".hpg": "application/vndhp-hpgl",
+ ".hpgl": "application/vndhp-hpgl",
+ ".hqx": "application/binhex",
+ ".hta": "application/hta",
+ ".htc": "text/x-component",
+ ".htm": "text/html",
+ ".html": "text/html",
+ ".htmls": "text/html",
+ ".htt": "text/webviewhtml",
+ ".htx": "text/html",
+ ".ice": "x-conference/x-cooltalk",
+ ".ico": "image/x-icon",
+ ".ics": "text/calendar",
+ ".icz": "text/calendar",
+ ".idc": "text/plain",
+ ".ief": "image/ief",
+ ".iefs": "image/ief",
+ ".iges": "application/iges",
+ ".igs": "application/iges",
+ ".ima": "application/x-ima",
+ ".imap": "application/x-httpd-imap",
+ ".inf": "application/inf",
+ ".ins": "application/x-internett-signup",
+ ".ip": "application/x-ip2",
+ ".isu": "video/x-isvideo",
+ ".it": "audio/it",
+ ".iv": "application/x-inventor",
+ ".ivr": "i-world/i-vrml",
+ ".ivy": "application/x-livescreen",
+ ".jam": "audio/x-jam",
+ ".jav": "text/x-java-source",
+ ".java": "text/x-java-source",
+ ".jcm": "application/x-java-commerce",
+ ".jfif-tbnl": "image/jpeg",
+ ".jfif": "image/jpeg",
+ ".jnlp": "application/x-java-jnlp-file",
+ ".jpe": "image/jpeg",
+ ".jpeg": "image/jpeg",
+ ".jpg": "image/jpeg",
+ ".jps": "image/x-jps",
+ ".js": "application/javascript",
+ ".json": "application/json",
+ ".jut": "image/jutvision",
+ ".kar": "audio/midi",
+ ".karbon": "application/vnd.kde.karbon",
+ ".kfo": "application/vnd.kde.kformula",
+ ".flw": "application/vnd.kde.kivio",
+ ".kml": "application/vnd.google-earth.kml+xml",
+ ".kmz": "application/vnd.google-earth.kmz",
+ ".kon": "application/vnd.kde.kontour",
+ ".kpr": "application/vnd.kde.kpresenter",
+ ".kpt": "application/vnd.kde.kpresenter",
+ ".ksp": "application/vnd.kde.kspread",
+ ".kwd": "application/vnd.kde.kword",
+ ".kwt": "application/vnd.kde.kword",
+ ".ksh": "text/x-scriptksh",
+ ".la": "audio/nspaudio",
+ ".lam": "audio/x-liveaudio",
+ ".latex": "application/x-latex",
+ ".lha": "application/lha",
+ ".lhx": "application/octet-stream",
+ ".list": "text/plain",
+ ".lma": "audio/nspaudio",
+ ".log": "text/plain",
+ ".lsp": "text/x-scriptlisp",
+ ".lst": "text/plain",
+ ".lsx": "text/x-la-asf",
+ ".ltx": "application/x-latex",
+ ".lzh": "application/octet-stream",
+ ".lzx": "application/lzx",
+ ".m1v": "video/mpeg",
+ ".m2a": "audio/mpeg",
+ ".m2v": "video/mpeg",
+ ".m3u": "audio/x-mpegurl",
+ ".m": "text/x-m",
+ ".man": "application/x-troff-man",
+ ".manifest": "text/cache-manifest",
+ ".map": "application/x-navimap",
+ ".mar": "text/plain",
+ ".mbd": "application/mbedlet",
+ ".mc$": "application/x-magic-cap-package-10",
+ ".mcd": "application/mcad",
+ ".mcf": "text/mcf",
+ ".mcp": "application/netmc",
+ ".me": "application/x-troff-me",
+ ".mht": "message/rfc822",
+ ".mhtml": "message/rfc822",
+ ".mid": "application/x-midi",
+ ".midi": "application/x-midi",
+ ".mif": "application/x-frame",
+ ".mime": "message/rfc822",
+ ".mjf": "audio/x-vndaudioexplosionmjuicemediafile",
+ ".mjpg": "video/x-motion-jpeg",
+ ".mm": "application/base64",
+ ".mme": "application/base64",
+ ".mod": "audio/mod",
+ ".moov": "video/quicktime",
+ ".mov": "video/quicktime",
+ ".movie": "video/x-sgi-movie",
+ ".mp2": "audio/mpeg",
+ ".mp3": "audio/mpeg3",
+ ".mp4": "video/mp4",
+ ".mpa": "audio/mpeg",
+ ".mpc": "application/x-project",
+ ".mpe": "video/mpeg",
+ ".mpeg": "video/mpeg",
+ ".mpg": "video/mpeg",
+ ".mpga": "audio/mpeg",
+ ".mpp": "application/vndms-project",
+ ".mpt": "application/x-project",
+ ".mpv": "application/x-project",
+ ".mpx": "application/x-project",
+ ".mrc": "application/marc",
+ ".ms": "application/x-troff-ms",
+ ".mv": "video/x-sgi-movie",
+ ".my": "audio/make",
+ ".mzz": "application/x-vndaudioexplosionmzz",
+ ".nap": "image/naplps",
+ ".naplps": "image/naplps",
+ ".nc": "application/x-netcdf",
+ ".ncm": "application/vndnokiaconfiguration-message",
+ ".nif": "image/x-niff",
+ ".niff": "image/x-niff",
+ ".nix": "application/x-mix-transfer",
+ ".nsc": "application/x-conference",
+ ".nvd": "application/x-navidoc",
+ ".o": "application/octet-stream",
+ ".oda": "application/oda",
+ ".odb": "application/vnd.oasis.opendocument.database",
+ ".odc": "application/vnd.oasis.opendocument.chart",
+ ".odf": "application/vnd.oasis.opendocument.formula",
+ ".odg": "application/vnd.oasis.opendocument.graphics",
+ ".odi": "application/vnd.oasis.opendocument.image",
+ ".odm": "application/vnd.oasis.opendocument.text-master",
+ ".odp": "application/vnd.oasis.opendocument.presentation",
+ ".ods": "application/vnd.oasis.opendocument.spreadsheet",
+ ".odt": "application/vnd.oasis.opendocument.text",
+ ".oga": "audio/ogg",
+ ".ogg": "audio/ogg",
+ ".ogv": "video/ogg",
+ ".omc": "application/x-omc",
+ ".omcd": "application/x-omcdatamaker",
+ ".omcr": "application/x-omcregerator",
+ ".otc": "application/vnd.oasis.opendocument.chart-template",
+ ".otf": "application/vnd.oasis.opendocument.formula-template",
+ ".otg": "application/vnd.oasis.opendocument.graphics-template",
+ ".oth": "application/vnd.oasis.opendocument.text-web",
+ ".oti": "application/vnd.oasis.opendocument.image-template",
+ ".otm": "application/vnd.oasis.opendocument.text-master",
+ ".otp": "application/vnd.oasis.opendocument.presentation-template",
+ ".ots": "application/vnd.oasis.opendocument.spreadsheet-template",
+ ".ott": "application/vnd.oasis.opendocument.text-template",
+ ".p10": "application/pkcs10",
+ ".p12": "application/pkcs-12",
+ ".p7a": "application/x-pkcs7-signature",
+ ".p7c": "application/pkcs7-mime",
+ ".p7m": "application/pkcs7-mime",
+ ".p7r": "application/x-pkcs7-certreqresp",
+ ".p7s": "application/pkcs7-signature",
+ ".p": "text/x-pascal",
+ ".part": "application/pro_eng",
+ ".pas": "text/pascal",
+ ".pbm": "image/x-portable-bitmap",
+ ".pcl": "application/vndhp-pcl",
+ ".pct": "image/x-pict",
+ ".pcx": "image/x-pcx",
+ ".pdb": "chemical/x-pdb",
+ ".pdf": "application/pdf",
+ ".pfunk": "audio/make",
+ ".pgm": "image/x-portable-graymap",
+ ".pic": "image/pict",
+ ".pict": "image/pict",
+ ".pkg": "application/x-newton-compatible-pkg",
+ ".pko": "application/vndms-pkipko",
+ ".pl": "text/x-scriptperl",
+ ".plx": "application/x-pixclscript",
+ ".pm4": "application/x-pagemaker",
+ ".pm5": "application/x-pagemaker",
+ ".pm": "text/x-scriptperl-module",
+ ".png": "image/png",
+ ".pnm": "application/x-portable-anymap",
+ ".pot": "application/mspowerpoint",
+ ".pov": "model/x-pov",
+ ".ppa": "application/vndms-powerpoint",
+ ".ppm": "image/x-portable-pixmap",
+ ".pps": "application/mspowerpoint",
+ ".ppt": "application/mspowerpoint",
+ ".ppz": "application/mspowerpoint",
+ ".pre": "application/x-freelance",
+ ".prt": "application/pro_eng",
+ ".ps": "application/postscript",
+ ".psd": "application/octet-stream",
+ ".pvu": "paleovu/x-pv",
+ ".pwz": "application/vndms-powerpoint",
+ ".py": "text/x-scriptphyton",
+ ".pyc": "application/x-bytecodepython",
+ ".qcp": "audio/vndqcelp",
+ ".qd3": "x-world/x-3dmf",
+ ".qd3d": "x-world/x-3dmf",
+ ".qif": "image/x-quicktime",
+ ".qt": "video/quicktime",
+ ".qtc": "video/x-qtc",
+ ".qti": "image/x-quicktime",
+ ".qtif": "image/x-quicktime",
+ ".ra": "audio/x-pn-realaudio",
+ ".ram": "audio/x-pn-realaudio",
+ ".rar": "application/x-rar-compressed",
+ ".ras": "application/x-cmu-raster",
+ ".rast": "image/cmu-raster",
+ ".rexx": "text/x-scriptrexx",
+ ".rf": "image/vndrn-realflash",
+ ".rgb": "image/x-rgb",
+ ".rm": "application/vndrn-realmedia",
+ ".rmi": "audio/mid",
+ ".rmm": "audio/x-pn-realaudio",
+ ".rmp": "audio/x-pn-realaudio",
+ ".rng": "application/ringing-tones",
+ ".rnx": "application/vndrn-realplayer",
+ ".roff": "application/x-troff",
+ ".rp": "image/vndrn-realpix",
+ ".rpm": "audio/x-pn-realaudio-plugin",
+ ".rt": "text/vndrn-realtext",
+ ".rtf": "text/richtext",
+ ".rtx": "text/richtext",
+ ".rv": "video/vndrn-realvideo",
+ ".s": "text/x-asm",
+ ".s3m": "audio/s3m",
+ ".s7z": "application/x-7z-compressed",
+ ".saveme": "application/octet-stream",
+ ".sbk": "application/x-tbook",
+ ".scm": "text/x-scriptscheme",
+ ".sdml": "text/plain",
+ ".sdp": "application/sdp",
+ ".sdr": "application/sounder",
+ ".sea": "application/sea",
+ ".set": "application/set",
+ ".sgm": "text/x-sgml",
+ ".sgml": "text/x-sgml",
+ ".sh": "text/x-scriptsh",
+ ".shar": "application/x-bsh",
+ ".shtml": "text/x-server-parsed-html",
+ ".sid": "audio/x-psid",
+ ".skd": "application/x-koan",
+ ".skm": "application/x-koan",
+ ".skp": "application/x-koan",
+ ".skt": "application/x-koan",
+ ".sit": "application/x-stuffit",
+ ".sitx": "application/x-stuffitx",
+ ".sl": "application/x-seelogo",
+ ".smi": "application/smil",
+ ".smil": "application/smil",
+ ".snd": "audio/basic",
+ ".sol": "application/solids",
+ ".spc": "text/x-speech",
+ ".spl": "application/futuresplash",
+ ".spr": "application/x-sprite",
+ ".sprite": "application/x-sprite",
+ ".spx": "audio/ogg",
+ ".src": "application/x-wais-source",
+ ".ssi": "text/x-server-parsed-html",
+ ".ssm": "application/streamingmedia",
+ ".sst": "application/vndms-pkicertstore",
+ ".step": "application/step",
+ ".stl": "application/sla",
+ ".stp": "application/step",
+ ".sv4cpio": "application/x-sv4cpio",
+ ".sv4crc": "application/x-sv4crc",
+ ".svf": "image/vnddwg",
+ ".svg": "image/svg+xml",
+ ".svr": "application/x-world",
+ ".swf": "application/x-shockwave-flash",
+ ".t": "application/x-troff",
+ ".talk": "text/x-speech",
+ ".tar": "application/x-tar",
+ ".tbk": "application/toolbook",
+ ".tcl": "text/x-scripttcl",
+ ".tcsh": "text/x-scripttcsh",
+ ".tex": "application/x-tex",
+ ".texi": "application/x-texinfo",
+ ".texinfo": "application/x-texinfo",
+ ".text": "text/plain",
+ ".tgz": "application/gnutar",
+ ".tif": "image/tiff",
+ ".tiff": "image/tiff",
+ ".tr": "application/x-troff",
+ ".tsi": "audio/tsp-audio",
+ ".tsp": "application/dsptype",
+ ".tsv": "text/tab-separated-values",
+ ".turbot": "image/florian",
+ ".txt": "text/plain",
+ ".uil": "text/x-uil",
+ ".uni": "text/uri-list",
+ ".unis": "text/uri-list",
+ ".unv": "application/i-deas",
+ ".uri": "text/uri-list",
+ ".uris": "text/uri-list",
+ ".ustar": "application/x-ustar",
+ ".uu": "text/x-uuencode",
+ ".uue": "text/x-uuencode",
+ ".vcd": "application/x-cdlink",
+ ".vcf": "text/x-vcard",
+ ".vcard": "text/x-vcard",
+ ".vcs": "text/x-vcalendar",
+ ".vda": "application/vda",
+ ".vdo": "video/vdo",
+ ".vew": "application/groupwise",
+ ".viv": "video/vivo",
+ ".vivo": "video/vivo",
+ ".vmd": "application/vocaltec-media-desc",
+ ".vmf": "application/vocaltec-media-file",
+ ".voc": "audio/voc",
+ ".vos": "video/vosaic",
+ ".vox": "audio/voxware",
+ ".vqe": "audio/x-twinvq-plugin",
+ ".vqf": "audio/x-twinvq",
+ ".vql": "audio/x-twinvq-plugin",
+ ".vrml": "application/x-vrml",
+ ".vrt": "x-world/x-vrt",
+ ".vsd": "application/x-visio",
+ ".vst": "application/x-visio",
+ ".vsw": "application/x-visio",
+ ".w60": "application/wordperfect60",
+ ".w61": "application/wordperfect61",
+ ".w6w": "application/msword",
+ ".wav": "audio/wav",
+ ".wb1": "application/x-qpro",
+ ".wbmp": "image/vnd.wap.wbmp",
+ ".web": "application/vndxara",
+ ".wiz": "application/msword",
+ ".wk1": "application/x-123",
+ ".wmf": "windows/metafile",
+ ".wml": "text/vnd.wap.wml",
+ ".wmlc": "application/vnd.wap.wmlc",
+ ".wmls": "text/vnd.wap.wmlscript",
+ ".wmlsc": "application/vnd.wap.wmlscriptc",
+ ".word": "application/msword",
+ ".wp5": "application/wordperfect",
+ ".wp6": "application/wordperfect",
+ ".wp": "application/wordperfect",
+ ".wpd": "application/wordperfect",
+ ".wq1": "application/x-lotus",
+ ".wri": "application/mswrite",
+ ".wrl": "application/x-world",
+ ".wrz": "model/vrml",
+ ".wsc": "text/scriplet",
+ ".wsrc": "application/x-wais-source",
+ ".wtk": "application/x-wintalk",
+ ".x-png": "image/png",
+ ".xbm": "image/x-xbitmap",
+ ".xdr": "video/x-amt-demorun",
+ ".xgz": "xgl/drawing",
+ ".xif": "image/vndxiff",
+ ".xl": "application/excel",
+ ".xla": "application/excel",
+ ".xlb": "application/excel",
+ ".xlc": "application/excel",
+ ".xld": "application/excel",
+ ".xlk": "application/excel",
+ ".xll": "application/excel",
+ ".xlm": "application/excel",
+ ".xls": "application/excel",
+ ".xlt": "application/excel",
+ ".xlv": "application/excel",
+ ".xlw": "application/excel",
+ ".xm": "audio/xm",
+ ".xml": "text/xml",
+ ".xmz": "xgl/movie",
+ ".xpix": "application/x-vndls-xpix",
+ ".xpm": "image/x-xpixmap",
+ ".xsr": "video/x-amt-showrun",
+ ".xwd": "image/x-xwd",
+ ".xyz": "chemical/x-pdb",
+ ".z": "application/x-compress",
+ ".zip": "application/zip",
+ ".zoo": "application/octet-stream",
+ ".zsh": "text/x-scriptzsh",
+ ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+ ".docm": "application/vnd.ms-word.document.macroEnabled.12",
+ ".dotx": "application/vnd.openxmlformats-officedocument.wordprocessingml.template",
+ ".dotm": "application/vnd.ms-word.template.macroEnabled.12",
+ ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+ ".xlsm": "application/vnd.ms-excel.sheet.macroEnabled.12",
+ ".xltx": "application/vnd.openxmlformats-officedocument.spreadsheetml.template",
+ ".xltm": "application/vnd.ms-excel.template.macroEnabled.12",
+ ".xlsb": "application/vnd.ms-excel.sheet.binary.macroEnabled.12",
+ ".xlam": "application/vnd.ms-excel.addin.macroEnabled.12",
+ ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+ ".pptm": "application/vnd.ms-powerpoint.presentation.macroEnabled.12",
+ ".ppsx": "application/vnd.openxmlformats-officedocument.presentationml.slideshow",
+ ".ppsm": "application/vnd.ms-powerpoint.slideshow.macroEnabled.12",
+ ".potx": "application/vnd.openxmlformats-officedocument.presentationml.template",
+ ".potm": "application/vnd.ms-powerpoint.template.macroEnabled.12",
+ ".ppam": "application/vnd.ms-powerpoint.addin.macroEnabled.12",
+ ".sldx": "application/vnd.openxmlformats-officedocument.presentationml.slide",
+ ".sldm": "application/vnd.ms-powerpoint.slide.macroEnabled.12",
+ ".thmx": "application/vnd.ms-officetheme",
+ ".onetoc": "application/onenote",
+ ".onetoc2": "application/onenote",
+ ".onetmp": "application/onenote",
+ ".onepkg": "application/onenote",
+ ".key": "application/x-iwork-keynote-sffkey",
+ ".kth": "application/x-iwork-keynote-sffkth",
+ ".nmbtemplate": "application/x-iwork-numbers-sfftemplate",
+ ".numbers": "application/x-iwork-numbers-sffnumbers",
+ ".pages": "application/x-iwork-pages-sffpages",
+ ".template": "application/x-iwork-pages-sfftemplate",
+ ".xpi": "application/x-xpinstall",
+ ".oex": "application/x-opera-extension",
+ ".mustache": "text/html",
+}
diff --git a/pkg/namespace.go b/pkg/namespace.go
new file mode 100644
index 00000000..4952c9d5
--- /dev/null
+++ b/pkg/namespace.go
@@ -0,0 +1,396 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "net/http"
+ "strings"
+
+ beecontext "github.com/astaxie/beego/context"
+)
+
+type namespaceCond func(*beecontext.Context) bool
+
+// LinkNamespace used as link action
+type LinkNamespace func(*Namespace)
+
+// Namespace stores the route handlers registered under a common URL prefix
+type Namespace struct {
+ prefix string
+ handlers *ControllerRegister
+}
+
+// NewNamespace get new Namespace
+func NewNamespace(prefix string, params ...LinkNamespace) *Namespace {
+ ns := &Namespace{
+ prefix: prefix,
+ handlers: NewControllerRegister(),
+ }
+ for _, p := range params {
+ p(ns)
+ }
+ return ns
+}
+
+// Cond set condition function
+// if cond return true can run this namespace, else can't
+// usage:
+// ns.Cond(func (ctx *context.Context) bool{
+// if ctx.Input.Domain() == "api.beego.me" {
+// return true
+// }
+// return false
+// })
+// Cond as the first filter
+func (n *Namespace) Cond(cond namespaceCond) *Namespace {
+ fn := func(ctx *beecontext.Context) {
+ if !cond(ctx) {
+ exception("405", ctx)
+ }
+ }
+ if v := n.handlers.filters[BeforeRouter]; len(v) > 0 {
+ mr := new(FilterRouter)
+ mr.tree = NewTree()
+ mr.pattern = "*"
+ mr.filterFunc = fn
+ mr.tree.AddRouter("*", true)
+ n.handlers.filters[BeforeRouter] = append([]*FilterRouter{mr}, v...)
+ } else {
+ n.handlers.InsertFilter("*", BeforeRouter, fn)
+ }
+ return n
+}
+
+// Filter add filter in the Namespace
+// action has before & after
+// FilterFunc
+// usage:
+// Filter("before", func (ctx *context.Context){
+// _, ok := ctx.Input.Session("uid").(int)
+// if !ok && ctx.Request.RequestURI != "/login" {
+// ctx.Redirect(302, "/login")
+// }
+// })
+func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
+ var a int
+ if action == "before" {
+ a = BeforeRouter
+ } else if action == "after" {
+ a = FinishRouter
+ }
+ for _, f := range filter {
+ n.handlers.InsertFilter("*", a, f)
+ }
+ return n
+}
+
+// Router same as beego.Router
+// refer: https://godoc.org/github.com/astaxie/beego#Router
+func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
+ n.handlers.Add(rootpath, c, mappingMethods...)
+ return n
+}
+
+// AutoRouter same as beego.AutoRouter
+// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter
+func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace {
+ n.handlers.AddAuto(c)
+ return n
+}
+
+// AutoPrefix same as beego.AutoPrefix
+// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix
+func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace {
+ n.handlers.AddAutoPrefix(prefix, c)
+ return n
+}
+
+// Get same as beego.Get
+// refer: https://godoc.org/github.com/astaxie/beego#Get
+func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace {
+ n.handlers.Get(rootpath, f)
+ return n
+}
+
+// Post same as beego.Post
+// refer: https://godoc.org/github.com/astaxie/beego#Post
+func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace {
+ n.handlers.Post(rootpath, f)
+ return n
+}
+
+// Delete same as beego.Delete
+// refer: https://godoc.org/github.com/astaxie/beego#Delete
+func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace {
+ n.handlers.Delete(rootpath, f)
+ return n
+}
+
+// Put same as beego.Put
+// refer: https://godoc.org/github.com/astaxie/beego#Put
+func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace {
+ n.handlers.Put(rootpath, f)
+ return n
+}
+
+// Head same as beego.Head
+// refer: https://godoc.org/github.com/astaxie/beego#Head
+func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace {
+ n.handlers.Head(rootpath, f)
+ return n
+}
+
+// Options same as beego.Options
+// refer: https://godoc.org/github.com/astaxie/beego#Options
+func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace {
+ n.handlers.Options(rootpath, f)
+ return n
+}
+
+// Patch same as beego.Patch
+// refer: https://godoc.org/github.com/astaxie/beego#Patch
+func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace {
+ n.handlers.Patch(rootpath, f)
+ return n
+}
+
+// Any same as beego.Any
+// refer: https://godoc.org/github.com/astaxie/beego#Any
+func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace {
+ n.handlers.Any(rootpath, f)
+ return n
+}
+
+// Handler same as beego.Handler
+// refer: https://godoc.org/github.com/astaxie/beego#Handler
+func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace {
+ n.handlers.Handler(rootpath, h)
+ return n
+}
+
+// Include add include class
+// refer: https://godoc.org/github.com/astaxie/beego#Include
+func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
+ n.handlers.Include(cList...)
+ return n
+}
+
+// Namespace add nest Namespace
+// usage:
+//ns := beego.NewNamespace("/v1").
+//Namespace(
+// beego.NewNamespace("/shop").
+// Get("/:id", func(ctx *context.Context) {
+// ctx.Output.Body([]byte("shopinfo"))
+// }),
+// beego.NewNamespace("/order").
+// Get("/:id", func(ctx *context.Context) {
+// ctx.Output.Body([]byte("orderinfo"))
+// }),
+// beego.NewNamespace("/crm").
+// Get("/:id", func(ctx *context.Context) {
+// ctx.Output.Body([]byte("crminfo"))
+// }),
+//)
+func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
+ for _, ni := range ns {
+ for k, v := range ni.handlers.routers {
+ if _, ok := n.handlers.routers[k]; ok {
+ addPrefix(v, ni.prefix)
+ n.handlers.routers[k].AddTree(ni.prefix, v)
+ } else {
+ t := NewTree()
+ t.AddTree(ni.prefix, v)
+ addPrefix(t, ni.prefix)
+ n.handlers.routers[k] = t
+ }
+ }
+ if ni.handlers.enableFilter {
+ for pos, filterList := range ni.handlers.filters {
+ for _, mr := range filterList {
+ t := NewTree()
+ t.AddTree(ni.prefix, mr.tree)
+ mr.tree = t
+ n.handlers.insertFilterRouter(pos, mr)
+ }
+ }
+ }
+ }
+ return n
+}
+
+// AddNamespace register Namespace into beego.Handler
+// support multi Namespace
+func AddNamespace(nl ...*Namespace) {
+ for _, n := range nl {
+ for k, v := range n.handlers.routers {
+ if _, ok := BeeApp.Handlers.routers[k]; ok {
+ addPrefix(v, n.prefix)
+ BeeApp.Handlers.routers[k].AddTree(n.prefix, v)
+ } else {
+ t := NewTree()
+ t.AddTree(n.prefix, v)
+ addPrefix(t, n.prefix)
+ BeeApp.Handlers.routers[k] = t
+ }
+ }
+ if n.handlers.enableFilter {
+ for pos, filterList := range n.handlers.filters {
+ for _, mr := range filterList {
+ t := NewTree()
+ t.AddTree(n.prefix, mr.tree)
+ mr.tree = t
+ BeeApp.Handlers.insertFilterRouter(pos, mr)
+ }
+ }
+ }
+ }
+}
+
+func addPrefix(t *Tree, prefix string) {
+ for _, v := range t.fixrouters {
+ addPrefix(v, prefix)
+ }
+ if t.wildcard != nil {
+ addPrefix(t.wildcard, prefix)
+ }
+ for _, l := range t.leaves {
+ if c, ok := l.runObject.(*ControllerInfo); ok {
+ if !strings.HasPrefix(c.pattern, prefix) {
+ c.pattern = prefix + c.pattern
+ }
+ }
+ }
+}
+
+// NSCond is Namespace Condition
+func NSCond(cond namespaceCond) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Cond(cond)
+ }
+}
+
+// NSBefore Namespace BeforeRouter filter
+func NSBefore(filterList ...FilterFunc) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Filter("before", filterList...)
+ }
+}
+
+// NSAfter add Namespace FinishRouter filter
+func NSAfter(filterList ...FilterFunc) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Filter("after", filterList...)
+ }
+}
+
+// NSInclude Namespace Include ControllerInterface
+func NSInclude(cList ...ControllerInterface) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Include(cList...)
+ }
+}
+
+// NSRouter call Namespace Router
+func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Router(rootpath, c, mappingMethods...)
+ }
+}
+
+// NSGet call Namespace Get
+func NSGet(rootpath string, f FilterFunc) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Get(rootpath, f)
+ }
+}
+
+// NSPost call Namespace Post
+func NSPost(rootpath string, f FilterFunc) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Post(rootpath, f)
+ }
+}
+
+// NSHead call Namespace Head
+func NSHead(rootpath string, f FilterFunc) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Head(rootpath, f)
+ }
+}
+
+// NSPut call Namespace Put
+func NSPut(rootpath string, f FilterFunc) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Put(rootpath, f)
+ }
+}
+
+// NSDelete call Namespace Delete
+func NSDelete(rootpath string, f FilterFunc) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Delete(rootpath, f)
+ }
+}
+
+// NSAny call Namespace Any
+func NSAny(rootpath string, f FilterFunc) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Any(rootpath, f)
+ }
+}
+
+// NSOptions call Namespace Options
+func NSOptions(rootpath string, f FilterFunc) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Options(rootpath, f)
+ }
+}
+
+// NSPatch call Namespace Patch
+func NSPatch(rootpath string, f FilterFunc) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Patch(rootpath, f)
+ }
+}
+
+// NSAutoRouter call Namespace AutoRouter
+func NSAutoRouter(c ControllerInterface) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.AutoRouter(c)
+ }
+}
+
+// NSAutoPrefix call Namespace AutoPrefix
+func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.AutoPrefix(prefix, c)
+ }
+}
+
+// NSNamespace add sub Namespace
+func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace {
+ return func(ns *Namespace) {
+ n := NewNamespace(prefix, params...)
+ ns.Namespace(n)
+ }
+}
+
+// NSHandler add handler
+func NSHandler(rootpath string, h http.Handler) LinkNamespace {
+ return func(ns *Namespace) {
+ ns.Handler(rootpath, h)
+ }
+}
diff --git a/pkg/namespace_test.go b/pkg/namespace_test.go
new file mode 100644
index 00000000..b3f20dff
--- /dev/null
+++ b/pkg/namespace_test.go
@@ -0,0 +1,168 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "testing"
+
+ "github.com/astaxie/beego/context"
+)
+
+func TestNamespaceGet(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/v1/user", nil)
+ w := httptest.NewRecorder()
+
+ ns := NewNamespace("/v1")
+ ns.Get("/user", func(ctx *context.Context) {
+ ctx.Output.Body([]byte("v1_user"))
+ })
+ AddNamespace(ns)
+ BeeApp.Handlers.ServeHTTP(w, r)
+ if w.Body.String() != "v1_user" {
+ t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String())
+ }
+}
+
+func TestNamespacePost(t *testing.T) {
+ r, _ := http.NewRequest("POST", "/v1/user/123", nil)
+ w := httptest.NewRecorder()
+
+ ns := NewNamespace("/v1")
+ ns.Post("/user/:id", func(ctx *context.Context) {
+ ctx.Output.Body([]byte(ctx.Input.Param(":id")))
+ })
+ AddNamespace(ns)
+ BeeApp.Handlers.ServeHTTP(w, r)
+ if w.Body.String() != "123" {
+ t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String())
+ }
+}
+
+func TestNamespaceNest(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/v1/admin/order", nil)
+ w := httptest.NewRecorder()
+
+ ns := NewNamespace("/v1")
+ ns.Namespace(
+ NewNamespace("/admin").
+ Get("/order", func(ctx *context.Context) {
+ ctx.Output.Body([]byte("order"))
+ }),
+ )
+ AddNamespace(ns)
+ BeeApp.Handlers.ServeHTTP(w, r)
+ if w.Body.String() != "order" {
+ t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String())
+ }
+}
+
+func TestNamespaceNestParam(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil)
+ w := httptest.NewRecorder()
+
+ ns := NewNamespace("/v1")
+ ns.Namespace(
+ NewNamespace("/admin").
+ Get("/order/:id", func(ctx *context.Context) {
+ ctx.Output.Body([]byte(ctx.Input.Param(":id")))
+ }),
+ )
+ AddNamespace(ns)
+ BeeApp.Handlers.ServeHTTP(w, r)
+ if w.Body.String() != "123" {
+ t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String())
+ }
+}
+
+func TestNamespaceRouter(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/v1/api/list", nil)
+ w := httptest.NewRecorder()
+
+ ns := NewNamespace("/v1")
+ ns.Router("/api/list", &TestController{}, "*:List")
+ AddNamespace(ns)
+ BeeApp.Handlers.ServeHTTP(w, r)
+ if w.Body.String() != "i am list" {
+ t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String())
+ }
+}
+
+func TestNamespaceAutoFunc(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/v1/test/list", nil)
+ w := httptest.NewRecorder()
+
+ ns := NewNamespace("/v1")
+ ns.AutoRouter(&TestController{})
+ AddNamespace(ns)
+ BeeApp.Handlers.ServeHTTP(w, r)
+ if w.Body.String() != "i am list" {
+ t.Errorf("user define func can't run")
+ }
+}
+
+func TestNamespaceFilter(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/v1/user/123", nil)
+ w := httptest.NewRecorder()
+
+ ns := NewNamespace("/v1")
+ ns.Filter("before", func(ctx *context.Context) {
+ ctx.Output.Body([]byte("this is Filter"))
+ }).
+ Get("/user/:id", func(ctx *context.Context) {
+ ctx.Output.Body([]byte(ctx.Input.Param(":id")))
+ })
+ AddNamespace(ns)
+ BeeApp.Handlers.ServeHTTP(w, r)
+ if w.Body.String() != "this is Filter" {
+ t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String())
+ }
+}
+
+func TestNamespaceCond(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/v2/test/list", nil)
+ w := httptest.NewRecorder()
+
+ ns := NewNamespace("/v2")
+ ns.Cond(func(ctx *context.Context) bool {
+ return ctx.Input.Domain() == "beego.me"
+ }).
+ AutoRouter(&TestController{})
+ AddNamespace(ns)
+ BeeApp.Handlers.ServeHTTP(w, r)
+ if w.Code != 405 {
+ t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code))
+ }
+}
+
+func TestNamespaceInside(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/v3/shop/order/123", nil)
+ w := httptest.NewRecorder()
+ ns := NewNamespace("/v3",
+ NSAutoRouter(&TestController{}),
+ NSNamespace("/shop",
+ NSGet("/order/:id", func(ctx *context.Context) {
+ ctx.Output.Body([]byte(ctx.Input.Param(":id")))
+ }),
+ ),
+ )
+ AddNamespace(ns)
+ BeeApp.Handlers.ServeHTTP(w, r)
+ if w.Body.String() != "123" {
+ t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String())
+ }
+}
diff --git a/pkg/parser.go b/pkg/parser.go
new file mode 100644
index 00000000..3a311894
--- /dev/null
+++ b/pkg/parser.go
@@ -0,0 +1,591 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "regexp"
+ "sort"
+ "strconv"
+ "strings"
+ "unicode"
+
+ "github.com/astaxie/beego/context/param"
+ "github.com/astaxie/beego/logs"
+ "github.com/astaxie/beego/utils"
+)
+
+var globalRouterTemplate = `package {{.routersDir}}
+
+import (
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context/param"{{.globalimport}}
+)
+
+func init() {
+{{.globalinfo}}
+}
+`
+
+var (
+ lastupdateFilename = "lastupdate.tmp"
+ commentFilename string
+ pkgLastupdate map[string]int64
+ genInfoList map[string][]ControllerComments
+
+ routerHooks = map[string]int{
+ "beego.BeforeStatic": BeforeStatic,
+ "beego.BeforeRouter": BeforeRouter,
+ "beego.BeforeExec": BeforeExec,
+ "beego.AfterExec": AfterExec,
+ "beego.FinishRouter": FinishRouter,
+ }
+
+ routerHooksMapping = map[int]string{
+ BeforeStatic: "beego.BeforeStatic",
+ BeforeRouter: "beego.BeforeRouter",
+ BeforeExec: "beego.BeforeExec",
+ AfterExec: "beego.AfterExec",
+ FinishRouter: "beego.FinishRouter",
+ }
+)
+
+const commentPrefix = "commentsRouter_"
+
+func init() {
+ pkgLastupdate = make(map[string]int64)
+}
+
+func parserPkg(pkgRealpath, pkgpath string) error {
+ rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_")
+ commentFilename, _ = filepath.Rel(AppPath, pkgRealpath)
+ commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go"
+ if !compareFile(pkgRealpath) {
+ logs.Info(pkgRealpath + " no changed")
+ return nil
+ }
+ genInfoList = make(map[string][]ControllerComments)
+ fileSet := token.NewFileSet()
+ astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool {
+ name := info.Name()
+ return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
+ }, parser.ParseComments)
+
+ if err != nil {
+ return err
+ }
+ for _, pkg := range astPkgs {
+ for _, fl := range pkg.Files {
+ for _, d := range fl.Decls {
+ switch specDecl := d.(type) {
+ case *ast.FuncDecl:
+ if specDecl.Recv != nil {
+				exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the receiver type is a pointer before handing it to the comment parser
+ if ok {
+ parserComments(specDecl, fmt.Sprint(exp.X), pkgpath)
+ }
+ }
+ }
+ }
+ }
+ }
+ genRouterCode(pkgRealpath)
+ savetoFile(pkgRealpath)
+ return nil
+}
+
+type parsedComment struct {
+ routerPath string
+ methods []string
+ params map[string]parsedParam
+ filters []parsedFilter
+ imports []parsedImport
+}
+
+type parsedImport struct {
+ importPath string
+ importAlias string
+}
+
+type parsedFilter struct {
+ pattern string
+ pos int
+ filter string
+ params []bool
+}
+
+type parsedParam struct {
+ name string
+ datatype string
+ location string
+ defValue string
+ required bool
+}
+
+func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
+ if f.Doc != nil {
+ parsedComments, err := parseComment(f.Doc.List)
+ if err != nil {
+ return err
+ }
+ for _, parsedComment := range parsedComments {
+ if parsedComment.routerPath != "" {
+ key := pkgpath + ":" + controllerName
+ cc := ControllerComments{}
+ cc.Method = f.Name.String()
+ cc.Router = parsedComment.routerPath
+ cc.AllowHTTPMethods = parsedComment.methods
+ cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
+ cc.FilterComments = buildFilters(parsedComment.filters)
+ cc.ImportComments = buildImports(parsedComment.imports)
+ genInfoList[key] = append(genInfoList[key], cc)
+ }
+ }
+ }
+ return nil
+}
+
+func buildImports(pis []parsedImport) []*ControllerImportComments {
+ var importComments []*ControllerImportComments
+
+ for _, pi := range pis {
+ importComments = append(importComments, &ControllerImportComments{
+ ImportPath: pi.importPath,
+ ImportAlias: pi.importAlias,
+ })
+ }
+
+ return importComments
+}
+
+func buildFilters(pfs []parsedFilter) []*ControllerFilterComments {
+ var filterComments []*ControllerFilterComments
+
+ for _, pf := range pfs {
+ var (
+ returnOnOutput bool
+ resetParams bool
+ )
+
+ if len(pf.params) >= 1 {
+ returnOnOutput = pf.params[0]
+ }
+
+ if len(pf.params) >= 2 {
+ resetParams = pf.params[1]
+ }
+
+ filterComments = append(filterComments, &ControllerFilterComments{
+ Filter: pf.filter,
+ Pattern: pf.pattern,
+ Pos: pf.pos,
+ ReturnOnOutput: returnOnOutput,
+ ResetParams: resetParams,
+ })
+ }
+
+ return filterComments
+}
+
+func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
+ result := make([]*param.MethodParam, 0, len(funcParams))
+ for _, fparam := range funcParams {
+ for _, pName := range fparam.Names {
+ methodParam := buildMethodParam(fparam, pName.Name, pc)
+ result = append(result, methodParam)
+ }
+ }
+ return result
+}
+
+func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam {
+ options := []param.MethodParamOption{}
+ if cparam, ok := pc.params[name]; ok {
+ //Build param from comment info
+ name = cparam.name
+ if cparam.required {
+ options = append(options, param.IsRequired)
+ }
+ switch cparam.location {
+ case "body":
+ options = append(options, param.InBody)
+ case "header":
+ options = append(options, param.InHeader)
+ case "path":
+ options = append(options, param.InPath)
+ }
+ if cparam.defValue != "" {
+ options = append(options, param.Default(cparam.defValue))
+ }
+ } else {
+ if paramInPath(name, pc.routerPath) {
+ options = append(options, param.InPath)
+ }
+ }
+ return param.New(name, options...)
+}
+
+func paramInPath(name, route string) bool {
+ return strings.HasSuffix(route, ":"+name) ||
+ strings.Contains(route, ":"+name+"/")
+}
+
+var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
+
+func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) {
+ pcs = []*parsedComment{}
+ params := map[string]parsedParam{}
+ filters := []parsedFilter{}
+ imports := []parsedImport{}
+
+ for _, c := range lines {
+ t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
+ if strings.HasPrefix(t, "@Param") {
+ pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param")))
+ if len(pv) < 4 {
+ logs.Error("Invalid @Param format. Needs at least 4 parameters")
+ }
+ p := parsedParam{}
+ names := strings.SplitN(pv[0], "=>", 2)
+ p.name = names[0]
+ funcParamName := p.name
+ if len(names) > 1 {
+ funcParamName = names[1]
+ }
+ p.location = pv[1]
+ p.datatype = pv[2]
+ switch len(pv) {
+ case 5:
+ p.required, _ = strconv.ParseBool(pv[3])
+ case 6:
+ p.defValue = pv[3]
+ p.required, _ = strconv.ParseBool(pv[4])
+ }
+ params[funcParamName] = p
+ }
+ }
+
+ for _, c := range lines {
+ t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
+ if strings.HasPrefix(t, "@Import") {
+ iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import")))
+ if len(iv) == 0 || len(iv) > 2 {
+ logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters")
+ continue
+ }
+
+ p := parsedImport{}
+ p.importPath = iv[0]
+
+ if len(iv) == 2 {
+ p.importAlias = iv[1]
+ }
+
+ imports = append(imports, p)
+ }
+ }
+
+filterLoop:
+ for _, c := range lines {
+ t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
+ if strings.HasPrefix(t, "@Filter") {
+ fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter")))
+ if len(fv) < 3 {
+ logs.Error("Invalid @Filter format. Needs at least 3 parameters")
+ continue filterLoop
+ }
+
+ p := parsedFilter{}
+ p.pattern = fv[0]
+ posName := fv[1]
+ if pos, exists := routerHooks[posName]; exists {
+ p.pos = pos
+ } else {
+ logs.Error("Invalid @Filter pos: ", posName)
+ continue filterLoop
+ }
+
+ p.filter = fv[2]
+ fvParams := fv[3:]
+ for _, fvParam := range fvParams {
+ switch fvParam {
+ case "true":
+ p.params = append(p.params, true)
+ case "false":
+ p.params = append(p.params, false)
+ default:
+ logs.Error("Invalid @Filter param: ", fvParam)
+ continue filterLoop
+ }
+ }
+
+ filters = append(filters, p)
+ }
+ }
+
+ for _, c := range lines {
+ var pc = &parsedComment{}
+ pc.params = params
+ pc.filters = filters
+ pc.imports = imports
+
+ t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
+ if strings.HasPrefix(t, "@router") {
+ t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
+ matches := routeRegex.FindStringSubmatch(t)
+ if len(matches) == 3 {
+ pc.routerPath = matches[1]
+ methods := matches[2]
+ if methods == "" {
+ pc.methods = []string{"get"}
+ //pc.hasGet = true
+ } else {
+ pc.methods = strings.Split(methods, ",")
+ //pc.hasGet = strings.Contains(methods, "get")
+ }
+ pcs = append(pcs, pc)
+ } else {
+ return nil, errors.New("Router information is missing")
+ }
+ }
+ }
+ return
+}
+
+// direct copy from bee\g_docs.go
+// analysis params return []string
+// @Param query form string true "The email for login"
+// [query form string true "The email for login"]
+func getparams(str string) []string {
+ var s []rune
+ var j int
+ var start bool
+ var r []string
+ var quoted int8
+ for _, c := range str {
+ if unicode.IsSpace(c) && quoted == 0 {
+ if !start {
+ continue
+ } else {
+ start = false
+ j++
+ r = append(r, string(s))
+ s = make([]rune, 0)
+ continue
+ }
+ }
+
+ start = true
+ if c == '"' {
+ quoted ^= 1
+ continue
+ }
+ s = append(s, c)
+ }
+ if len(s) > 0 {
+ r = append(r, string(s))
+ }
+ return r
+}
+
+func genRouterCode(pkgRealpath string) {
+ os.Mkdir(getRouterDir(pkgRealpath), 0755)
+ logs.Info("generate router from comments")
+ var (
+ globalinfo string
+ globalimport string
+ sortKey []string
+ )
+ for k := range genInfoList {
+ sortKey = append(sortKey, k)
+ }
+ sort.Strings(sortKey)
+ for _, k := range sortKey {
+ cList := genInfoList[k]
+ sort.Sort(ControllerCommentsSlice(cList))
+ for _, c := range cList {
+ allmethod := "nil"
+ if len(c.AllowHTTPMethods) > 0 {
+ allmethod = "[]string{"
+ for _, m := range c.AllowHTTPMethods {
+ allmethod += `"` + m + `",`
+ }
+ allmethod = strings.TrimRight(allmethod, ",") + "}"
+ }
+
+ params := "nil"
+ if len(c.Params) > 0 {
+ params = "[]map[string]string{"
+ for _, p := range c.Params {
+ for k, v := range p {
+ params = params + `map[string]string{` + k + `:"` + v + `"},`
+ }
+ }
+ params = strings.TrimRight(params, ",") + "}"
+ }
+
+ methodParams := "param.Make("
+ if len(c.MethodParams) > 0 {
+ lines := make([]string, 0, len(c.MethodParams))
+ for _, m := range c.MethodParams {
+ lines = append(lines, fmt.Sprint(m))
+ }
+ methodParams += "\n " +
+ strings.Join(lines, ",\n ") +
+ ",\n "
+ }
+ methodParams += ")"
+
+ imports := ""
+ if len(c.ImportComments) > 0 {
+ for _, i := range c.ImportComments {
+ var s string
+ if i.ImportAlias != "" {
+ s = fmt.Sprintf(`
+ %s "%s"`, i.ImportAlias, i.ImportPath)
+ } else {
+ s = fmt.Sprintf(`
+ "%s"`, i.ImportPath)
+ }
+ if !strings.Contains(globalimport, s) {
+ imports += s
+ }
+ }
+ }
+
+ filters := ""
+ if len(c.FilterComments) > 0 {
+ for _, f := range c.FilterComments {
+ filters += fmt.Sprintf(` &beego.ControllerFilter{
+ Pattern: "%s",
+ Pos: %s,
+ Filter: %s,
+ ReturnOnOutput: %v,
+ ResetParams: %v,
+ },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams)
+ }
+ }
+
+ if filters == "" {
+ filters = "nil"
+ } else {
+ filters = fmt.Sprintf(`[]*beego.ControllerFilter{
+%s
+ }`, filters)
+ }
+
+ globalimport += imports
+
+ globalinfo = globalinfo + `
+ beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
+ beego.ControllerComments{
+ Method: "` + strings.TrimSpace(c.Method) + `",
+ ` + `Router: "` + c.Router + `"` + `,
+ AllowHTTPMethods: ` + allmethod + `,
+ MethodParams: ` + methodParams + `,
+ Filters: ` + filters + `,
+ Params: ` + params + `})
+`
+ }
+ }
+
+ if globalinfo != "" {
+ f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename))
+ if err != nil {
+ panic(err)
+ }
+ defer f.Close()
+
+ routersDir := AppConfig.DefaultString("routersdir", "routers")
+ content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1)
+ content = strings.Replace(content, "{{.routersDir}}", routersDir, -1)
+ content = strings.Replace(content, "{{.globalimport}}", globalimport, -1)
+ f.WriteString(content)
+ }
+}
+
+func compareFile(pkgRealpath string) bool {
+ if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) {
+ return true
+ }
+ if utils.FileExists(lastupdateFilename) {
+ content, err := ioutil.ReadFile(lastupdateFilename)
+ if err != nil {
+ return true
+ }
+ json.Unmarshal(content, &pkgLastupdate)
+ lastupdate, err := getpathTime(pkgRealpath)
+ if err != nil {
+ return true
+ }
+ if v, ok := pkgLastupdate[pkgRealpath]; ok {
+ if lastupdate <= v {
+ return false
+ }
+ }
+ }
+ return true
+}
+
+func savetoFile(pkgRealpath string) {
+ lastupdate, err := getpathTime(pkgRealpath)
+ if err != nil {
+ return
+ }
+ pkgLastupdate[pkgRealpath] = lastupdate
+ d, err := json.Marshal(pkgLastupdate)
+ if err != nil {
+ return
+ }
+ ioutil.WriteFile(lastupdateFilename, d, os.ModePerm)
+}
+
+func getpathTime(pkgRealpath string) (lastupdate int64, err error) {
+ fl, err := ioutil.ReadDir(pkgRealpath)
+ if err != nil {
+ return lastupdate, err
+ }
+ for _, f := range fl {
+ if lastupdate < f.ModTime().UnixNano() {
+ lastupdate = f.ModTime().UnixNano()
+ }
+ }
+ return lastupdate, nil
+}
+
+func getRouterDir(pkgRealpath string) string {
+ dir := filepath.Dir(pkgRealpath)
+ for {
+ routersDir := AppConfig.DefaultString("routersdir", "routers")
+ d := filepath.Join(dir, routersDir)
+ if utils.FileExists(d) {
+ return d
+ }
+
+ if r, _ := filepath.Rel(dir, AppPath); r == "." {
+ return d
+ }
+ // Parent dir.
+ dir = filepath.Dir(dir)
+ }
+}
diff --git a/pkg/plugins/apiauth/apiauth.go b/pkg/plugins/apiauth/apiauth.go
new file mode 100644
index 00000000..10e25f3f
--- /dev/null
+++ b/pkg/plugins/apiauth/apiauth.go
@@ -0,0 +1,165 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package apiauth provides handlers to enable apiauth support.
+//
+// Simple Usage:
+// import(
+// "github.com/astaxie/beego"
+// "github.com/astaxie/beego/plugins/apiauth"
+// )
+//
+// func main(){
+// // apiauth every request
+// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey"))
+// beego.Run()
+// }
+//
+// Advanced Usage:
+//
+// func getAppSecret(appid string) string {
+// // get appsecret by appid
+// // maybe store in configure, maybe in database
+// }
+//
+// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360))
+//
+// Information:
+//
+// In the request user should include these params in the query
+//
+// 1. appid
+//
+// appid is assigned to the application
+//
+// 2. signature
+//
+// get the signature use apiauth.Signature()
+//
+// when you send to server remember use url.QueryEscape()
+//
+// 3. timestamp:
+//
+// send the request time, the format is yyyy-mm-dd HH:ii:ss
+//
+package apiauth
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "net/url"
+ "sort"
+ "time"
+
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context"
+)
+
+// AppIDToAppSecret is used to look up the appsecret through the appid
+type AppIDToAppSecret func(string) string
+
+// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret
+func APIBasicAuth(appid, appkey string) beego.FilterFunc {
+ ft := func(aid string) string {
+ if aid == appid {
+ return appkey
+ }
+ return ""
+ }
+ return APISecretAuth(ft, 300)
+}
+
+// APIBaiscAuth calls APIBasicAuth for previous callers
+func APIBaiscAuth(appid, appkey string) beego.FilterFunc {
+ return APIBasicAuth(appid, appkey)
+}
+
+// APISecretAuth uses AppIDToAppSecret to verify the request signature within the timeout window
+func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc {
+ return func(ctx *context.Context) {
+ if ctx.Input.Query("appid") == "" {
+ ctx.ResponseWriter.WriteHeader(403)
+ ctx.WriteString("miss query param: appid")
+ return
+ }
+ appsecret := f(ctx.Input.Query("appid"))
+ if appsecret == "" {
+ ctx.ResponseWriter.WriteHeader(403)
+ ctx.WriteString("not exist this appid")
+ return
+ }
+ if ctx.Input.Query("signature") == "" {
+ ctx.ResponseWriter.WriteHeader(403)
+ ctx.WriteString("miss query param: signature")
+ return
+ }
+ if ctx.Input.Query("timestamp") == "" {
+ ctx.ResponseWriter.WriteHeader(403)
+ ctx.WriteString("miss query param: timestamp")
+ return
+ }
+ u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp"))
+ if err != nil {
+ ctx.ResponseWriter.WriteHeader(403)
+ ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05")
+ return
+ }
+ t := time.Now()
+ if t.Sub(u).Seconds() > float64(timeout) {
+ ctx.ResponseWriter.WriteHeader(403)
+ ctx.WriteString("timeout! the request time is long ago, please try again")
+ return
+ }
+ if ctx.Input.Query("signature") !=
+ Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URL()) {
+ ctx.ResponseWriter.WriteHeader(403)
+ ctx.WriteString("auth failed")
+ }
+ }
+}
+
+// Signature computes the base64-encoded HMAC-SHA256 signature over the method, sorted params and RequestURL.
+func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) {
+	var b bytes.Buffer
+	keys := make([]string, 0, len(params)) // length 0, capacity n: a non-zero length would prepend empty keys
+	pa := make(map[string]string)
+	for k, v := range params {
+		pa[k] = v[0]
+		keys = append(keys, k)
+	}
+
+	sort.Strings(keys)
+
+	for _, key := range keys {
+		if key == "signature" {
+			continue
+		}
+
+		val := pa[key]
+		if key != "" && val != "" {
+			b.WriteString(key)
+			b.WriteString(val)
+		}
+	}
+
+	stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL)
+
+	// Key the MAC with the appsecret; avoid shadowing the sha256 package name.
+	mac := hmac.New(sha256.New, []byte(appsecret))
+	mac.Write([]byte(stringToSign))
+	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
+}
diff --git a/pkg/plugins/apiauth/apiauth_test.go b/pkg/plugins/apiauth/apiauth_test.go
new file mode 100644
index 00000000..1f56cb0f
--- /dev/null
+++ b/pkg/plugins/apiauth/apiauth_test.go
@@ -0,0 +1,20 @@
+package apiauth
+
+import (
+ "net/url"
+ "testing"
+)
+
+func TestSignature(t *testing.T) { // pins the exact HMAC-SHA256/base64 output for a fixed secret, method, params and URL
+	appsecret := "beego secret"
+	method := "GET"
+	RequestURL := "http://localhost/test/url"
+	params := make(url.Values)
+	params.Add("arg1", "hello")
+	params.Add("arg2", "beego")
+
+	signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58=" // golden value; changes here indicate a signing-format break
+	if Signature(appsecret, method, params, RequestURL) != signature {
+		t.Error("Signature error")
+	}
+}
diff --git a/pkg/plugins/auth/basic.go b/pkg/plugins/auth/basic.go
new file mode 100644
index 00000000..c478044a
--- /dev/null
+++ b/pkg/plugins/auth/basic.go
@@ -0,0 +1,107 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package auth provides handlers to enable basic auth support.
+// Simple Usage:
+// import(
+// "github.com/astaxie/beego"
+// "github.com/astaxie/beego/plugins/auth"
+// )
+//
+// func main(){
+// // authenticate every request
+// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword"))
+// beego.Run()
+// }
+//
+//
+// Advanced Usage:
+//
+// func SecretAuth(username, password string) bool {
+// return username == "astaxie" && password == "helloBeego"
+// }
+// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required")
+// beego.InsertFilter("*", beego.BeforeRouter,authPlugin)
+package auth
+
+import (
+ "encoding/base64"
+ "net/http"
+ "strings"
+
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context"
+)
+
+var defaultRealm = "Authorization Required"
+
+// Basic returns a filter enforcing HTTP Basic auth against the fixed username/password pair.
+func Basic(username string, password string) beego.FilterFunc {
+	secrets := func(user, pass string) bool {
+		return user == username && pass == password // constant credentials captured by the closure
+	}
+	return NewBasicAuthenticator(secrets, defaultRealm)
+}
+
+// NewBasicAuthenticator returns a filter that challenges (401 + WWW-Authenticate) any request the SecretProvider rejects.
+func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc {
+	return func(ctx *context.Context) {
+		a := &BasicAuth{Secrets: secrets, Realm: Realm} // fresh per request; BasicAuth holds no mutable state
+		if username := a.CheckAuth(ctx.Request); username == "" {
+			a.RequireAuth(ctx.ResponseWriter, ctx.Request)
+		}
+	}
+}
+
+// SecretProvider is the SecretProvider function
+type SecretProvider func(user, pass string) bool
+
+// BasicAuth store the SecretProvider and Realm
+type BasicAuth struct {
+ Secrets SecretProvider
+ Realm string
+}
+
+// CheckAuth Checks the username/password combination from the request. Returns
+// either an empty string (authentication failed) or the name of the
+// authenticated user.
+// Only the "Basic" scheme is handled; credentials are validated by the SecretProvider.
+func (a *BasicAuth) CheckAuth(r *http.Request) string {
+	s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) // expected form: "Basic <base64(user:pass)>"
+	if len(s) != 2 || s[0] != "Basic" {
+		return ""
+	}
+
+	b, err := base64.StdEncoding.DecodeString(s[1])
+	if err != nil {
+		return "" // malformed base64 is treated the same as missing credentials
+	}
+	pair := strings.SplitN(string(b), ":", 2) // SplitN(…, 2) keeps any ':' inside the password intact
+	if len(pair) != 2 {
+		return ""
+	}
+
+	if a.Secrets(pair[0], pair[1]) {
+		return pair[0]
+	}
+	return ""
+}
+
+// RequireAuth http.Handler for BasicAuth which initiates the authentication process
+// (or requires reauthentication).
+func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) {
+	w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`) // prompts the browser's credential dialog
+	w.WriteHeader(401)
+	w.Write([]byte("401 Unauthorized\n"))
+}
diff --git a/pkg/plugins/authz/authz.go b/pkg/plugins/authz/authz.go
new file mode 100644
index 00000000..9dc0db76
--- /dev/null
+++ b/pkg/plugins/authz/authz.go
@@ -0,0 +1,86 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support.
+// Simple Usage:
+// import(
+// "github.com/astaxie/beego"
+// "github.com/astaxie/beego/plugins/authz"
+// "github.com/casbin/casbin"
+// )
+//
+// func main(){
+// // mediate the access for every request
+// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
+// beego.Run()
+// }
+//
+//
+// Advanced Usage:
+//
+// func main(){
+// e := casbin.NewEnforcer("authz_model.conf", "")
+// e.AddRoleForUser("alice", "admin")
+// e.AddPolicy(...)
+//
+// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e))
+// beego.Run()
+// }
+package authz
+
+import (
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context"
+ "github.com/casbin/casbin"
+ "net/http"
+)
+
+// NewAuthorizer returns the authorizer.
+// Use a casbin enforcer as input
+func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc {
+	return func(ctx *context.Context) {
+		a := &BasicAuthorizer{enforcer: e}
+
+		if !a.CheckPermission(ctx.Request) {
+			a.RequirePermission(ctx.ResponseWriter) // deny with 403 before the route executes
+		}
+	}
+}
+
+// BasicAuthorizer stores the casbin handler
+type BasicAuthorizer struct {
+ enforcer *casbin.Enforcer
+}
+
+// GetUserName gets the user name from the request.
+// Currently, only HTTP basic authentication is supported
+func (a *BasicAuthorizer) GetUserName(r *http.Request) string {
+	username, _, _ := r.BasicAuth()
+	return username // empty string when no (or malformed) Basic credentials are present
+}
+
+// CheckPermission checks the user/method/path combination from the request.
+// Returns true (permission granted) or false (permission forbidden)
+func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool {
+	user := a.GetUserName(r)
+	method := r.Method
+	path := r.URL.Path
+	return a.enforcer.Enforce(user, path, method) // casbin request order is (sub, obj, act) = (user, path, method)
+}
+
+// RequirePermission returns the 403 Forbidden to the client
+func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) {
+	w.WriteHeader(403) // authenticated but not authorized, hence 403 rather than 401
+	w.Write([]byte("403 Forbidden\n"))
+}
diff --git a/pkg/plugins/authz/authz_model.conf b/pkg/plugins/authz/authz_model.conf
new file mode 100644
index 00000000..d1b3dbd7
--- /dev/null
+++ b/pkg/plugins/authz/authz_model.conf
@@ -0,0 +1,14 @@
+[request_definition]
+r = sub, obj, act
+
+[policy_definition]
+p = sub, obj, act
+
+[role_definition]
+g = _, _
+
+[policy_effect]
+e = some(where (p.eft == allow))
+
+[matchers]
+m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*")
\ No newline at end of file
diff --git a/pkg/plugins/authz/authz_policy.csv b/pkg/plugins/authz/authz_policy.csv
new file mode 100644
index 00000000..c062dd3e
--- /dev/null
+++ b/pkg/plugins/authz/authz_policy.csv
@@ -0,0 +1,7 @@
+p, alice, /dataset1/*, GET
+p, alice, /dataset1/resource1, POST
+p, bob, /dataset2/resource1, *
+p, bob, /dataset2/resource2, GET
+p, bob, /dataset2/folder1/*, POST
+p, dataset1_admin, /dataset1/*, *
+g, cathy, dataset1_admin
\ No newline at end of file
diff --git a/pkg/plugins/authz/authz_test.go b/pkg/plugins/authz/authz_test.go
new file mode 100644
index 00000000..49aed84c
--- /dev/null
+++ b/pkg/plugins/authz/authz_test.go
@@ -0,0 +1,107 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package authz
+
+import (
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/plugins/auth"
+ "github.com/casbin/casbin"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { // helper: issues method/path as user and asserts the status code
+	r, _ := http.NewRequest(method, path, nil)
+	r.SetBasicAuth(user, "123") // password must match the auth.Basic filter installed by the caller
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, r)
+
+	if w.Code != code {
+		t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code)
+	}
+}
+
+func TestBasic(t *testing.T) { // ACL: alice's per-path policies from authz_policy.csv
+	handler := beego.NewControllerRegister()
+
+	handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123"))
+	handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
+
+	handler.Any("*", func(ctx *context.Context) {
+		ctx.Output.SetStatus(200)
+	})
+
+	testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200)
+	testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200)
+	testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) // matched by the /dataset1/* GET wildcard rule
+	testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403)
+}
+
+func TestPathWildcard(t *testing.T) { // keyMatch wildcards in paths plus the "*" action wildcard
+	handler := beego.NewControllerRegister()
+
+	handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123"))
+	handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
+
+	handler.Any("*", func(ctx *context.Context) {
+		ctx.Output.SetStatus(200)
+	})
+
+	testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) // "*" action rule allows every method here
+	testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200)
+	testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200)
+	testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200)
+	testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403)
+	testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403)
+
+	testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) // folder1/* only grants POST
+	testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200)
+	testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403)
+	testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403)
+	testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200)
+	testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403)
+}
+
+func TestRBAC(t *testing.T) { // role inheritance (g rules) and runtime role revocation
+	handler := beego.NewControllerRegister()
+
+	handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123"))
+	e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") // enforcer kept so roles can be mutated mid-test
+	handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e))
+
+	handler.Any("*", func(ctx *context.Context) {
+		ctx.Output.SetStatus(200)
+	})
+
+	// cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role.
+	testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200)
+	testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200)
+	testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200)
+	testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
+	testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
+	testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
+
+	// delete all roles on user cathy, so cathy cannot access any resources now.
+	e.DeleteRolesForUser("cathy")
+
+	testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403)
+	testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403)
+	testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403)
+	testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
+	testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
+	testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
+}
diff --git a/pkg/plugins/cors/cors.go b/pkg/plugins/cors/cors.go
new file mode 100644
index 00000000..45c327ab
--- /dev/null
+++ b/pkg/plugins/cors/cors.go
@@ -0,0 +1,228 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cors provides handlers to enable CORS support.
+// Usage
+// import (
+// "github.com/astaxie/beego"
+// "github.com/astaxie/beego/plugins/cors"
+// )
+//
+// func main() {
+// // CORS for https://foo.* origins, allowing:
+// // - PUT and PATCH methods
+// // - Origin header
+// // - Credentials share
+// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{
+// AllowOrigins: []string{"https://*.foo.com"},
+// AllowMethods: []string{"PUT", "PATCH"},
+// AllowHeaders: []string{"Origin"},
+// ExposeHeaders: []string{"Content-Length"},
+// AllowCredentials: true,
+// }))
+// beego.Run()
+// }
+package cors
+
+import (
+ "net/http"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context"
+)
+
+const (
+ headerAllowOrigin = "Access-Control-Allow-Origin"
+ headerAllowCredentials = "Access-Control-Allow-Credentials"
+ headerAllowHeaders = "Access-Control-Allow-Headers"
+ headerAllowMethods = "Access-Control-Allow-Methods"
+ headerExposeHeaders = "Access-Control-Expose-Headers"
+ headerMaxAge = "Access-Control-Max-Age"
+
+ headerOrigin = "Origin"
+ headerRequestMethod = "Access-Control-Request-Method"
+ headerRequestHeaders = "Access-Control-Request-Headers"
+)
+
+var (
+ defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"}
+ // Regex patterns are generated from AllowOrigins. These are used and generated internally.
+ allowOriginPatterns = []string{}
+)
+
+// Options represents Access Control options.
+type Options struct {
+ // If set, all origins are allowed.
+ AllowAllOrigins bool
+ // A list of allowed origins. Wild cards and FQDNs are supported.
+ AllowOrigins []string
+ // If set, allows to share auth credentials such as cookies.
+ AllowCredentials bool
+ // A list of allowed HTTP methods.
+ AllowMethods []string
+ // A list of allowed HTTP headers.
+ AllowHeaders []string
+ // A list of exposed HTTP headers.
+ ExposeHeaders []string
+ // Max age of the CORS headers.
+ MaxAge time.Duration
+}
+
+// Header converts options into CORS headers.
+func (o *Options) Header(origin string) (headers map[string]string) {
+	headers = make(map[string]string)
+	// if origin is not allowed, don't extend the headers
+	// with CORS headers.
+	if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
+		return // empty map: caller emits no CORS headers for a disallowed origin
+	}
+
+	// add allow origin
+	if o.AllowAllOrigins {
+		headers[headerAllowOrigin] = "*"
+	} else {
+		headers[headerAllowOrigin] = origin // echo the request origin rather than a pattern
+	}
+
+	// add allow credentials
+	headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
+
+	// add allow methods
+	if len(o.AllowMethods) > 0 {
+		headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
+	}
+
+	// add allow headers
+	if len(o.AllowHeaders) > 0 {
+		headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",")
+	}
+
+	// add exposed header
+	if len(o.ExposeHeaders) > 0 {
+		headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
+	}
+	// add a max age header
+	if o.MaxAge > time.Duration(0) {
+		headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) // expressed in whole seconds
+	}
+	return
+}
+
+// PreflightHeader converts options into CORS headers for a preflight (OPTIONS) response.
+func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) {
+	headers = make(map[string]string)
+	if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
+		return // disallowed origin: no CORS headers at all
+	}
+	// verify if requested method is allowed
+	for _, method := range o.AllowMethods {
+		if method == rMethod {
+			headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
+			break
+		}
+	}
+
+	// verify if requested headers are allowed
+	var allowed []string
+	for _, rHeader := range strings.Split(rHeaders, ",") {
+		rHeader = strings.TrimSpace(rHeader)
+	lookupLoop:
+		for _, allowedHeader := range o.AllowHeaders {
+			if strings.EqualFold(rHeader, allowedHeader) { // header names are case-insensitive; EqualFold avoids the two ToLower allocations
+				allowed = append(allowed, rHeader)
+				break lookupLoop
+			}
+		}
+	}
+
+	headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
+	// add allow origin
+	if o.AllowAllOrigins {
+		headers[headerAllowOrigin] = "*"
+	} else {
+		headers[headerAllowOrigin] = origin
+	}
+
+	// add allowed headers
+	if len(allowed) > 0 {
+		headers[headerAllowHeaders] = strings.Join(allowed, ",") // echo only the requested headers that matched
+	}
+
+	// add exposed headers
+	if len(o.ExposeHeaders) > 0 {
+		headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
+	}
+	// add a max age header
+	if o.MaxAge > time.Duration(0) {
+		headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
+	}
+	return
+}
+
+// IsOriginAllowed looks up if the origin matches one of the patterns
+// generated from Options.AllowOrigins (stored in the package-level allowOriginPatterns slice, not on o).
+func (o *Options) IsOriginAllowed(origin string) (allowed bool) {
+	for _, pattern := range allowOriginPatterns { // NOTE(review): reads global state populated by Allow(); confirm a single Allow per process
+		allowed, _ = regexp.MatchString(pattern, origin)
+		if allowed {
+			return
+		}
+	}
+	return
+}
+
+// Allow enables CORS for requests those match the provided options.
+func Allow(opts *Options) beego.FilterFunc {
+	// Allow default headers if nothing is specified.
+	if len(opts.AllowHeaders) == 0 {
+		opts.AllowHeaders = defaultAllowHeaders
+	}
+
+	for _, origin := range opts.AllowOrigins {
+		pattern := regexp.QuoteMeta(origin) // escape, then re-enable only the '*' and '?' wildcards
+		pattern = strings.Replace(pattern, "\\*", ".*", -1)
+		pattern = strings.Replace(pattern, "\\?", ".", -1)
+		allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$") // NOTE(review): appends to package-level state; a second Allow() call accumulates patterns — confirm intended
+	}
+
+	return func(ctx *context.Context) {
+		var (
+			origin = ctx.Input.Header(headerOrigin)
+			requestedMethod = ctx.Input.Header(headerRequestMethod)
+			requestedHeaders = ctx.Input.Header(headerRequestHeaders)
+			// additional headers to be added
+			// to the response.
+			headers map[string]string
+		)
+
+		if ctx.Input.Method() == "OPTIONS" &&
+			(requestedMethod != "" || requestedHeaders != "") { // a true preflight carries at least one request-* header
+			headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders)
+			for key, value := range headers {
+				ctx.Output.Header(key, value)
+			}
+			ctx.ResponseWriter.WriteHeader(http.StatusOK)
+			return // preflight is fully answered here; never reaches the router
+		}
+		headers = opts.Header(origin)
+
+		for key, value := range headers {
+			ctx.Output.Header(key, value)
+		}
+	}
+}
diff --git a/pkg/plugins/cors/cors_test.go b/pkg/plugins/cors/cors_test.go
new file mode 100644
index 00000000..34039143
--- /dev/null
+++ b/pkg/plugins/cors/cors_test.go
@@ -0,0 +1,253 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cors
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context"
+)
+
+// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header
+type HTTPHeaderGuardRecorder struct {
+	*httptest.ResponseRecorder
+	savedHeaderMap http.Header // snapshot of the header map as of the first WriteHeader call
+}
+
+// NewRecorder return HttpHeaderGuardRecorder
+func NewRecorder() *HTTPHeaderGuardRecorder {
+	return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil}
+}
+
+func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { // capture the header state at write time
+	gr.ResponseRecorder.WriteHeader(code)
+	gr.savedHeaderMap = gr.ResponseRecorder.Header()
+}
+
+func (gr *HTTPHeaderGuardRecorder) Header() http.Header { // after WriteHeader, callers get a copy so later mutations are invisible
+	if gr.savedHeaderMap != nil {
+		// headers were written. clone so we don't get updates
+		clone := make(http.Header)
+		for k, v := range gr.savedHeaderMap {
+			clone[k] = v
+		}
+		return clone
+	}
+	return gr.ResponseRecorder.Header()
+}
+
+func Test_AllowAll(t *testing.T) { // AllowAllOrigins must produce a literal "*" Allow-Origin header
+	recorder := httptest.NewRecorder()
+	handler := beego.NewControllerRegister()
+	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
+		AllowAllOrigins: true,
+	}))
+	handler.Any("/foo", func(ctx *context.Context) {
+		ctx.Output.SetStatus(500) // handler status is irrelevant; only the headers are asserted
+	})
+	r, _ := http.NewRequest("PUT", "/foo", nil)
+	handler.ServeHTTP(recorder, r)
+
+	if recorder.HeaderMap.Get(headerAllowOrigin) != "*" {
+		t.Errorf("Allow-Origin header should be *")
+	}
+}
+
+func Test_AllowRegexMatch(t *testing.T) { // wildcard origin pattern matches a subdomain and echoes the origin back
+	recorder := httptest.NewRecorder()
+	handler := beego.NewControllerRegister()
+	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
+		AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"},
+	}))
+	handler.Any("/foo", func(ctx *context.Context) {
+		ctx.Output.SetStatus(500)
+	})
+	origin := "https://bar.foo.com"
+	r, _ := http.NewRequest("PUT", "/foo", nil)
+	r.Header.Add("Origin", origin)
+	handler.ServeHTTP(recorder, r)
+
+	headersVal := recorder.HeaderMap.Get(headerAllowOrigin)
+	if headersVal != origin {
+		t.Errorf("Allow-Origin header should be %v, found %v", origin, headersVal)
+	}
+}
+
+func Test_AllowRegexNoMatch(t *testing.T) { // the ^…$ anchors must reject origins that merely contain the pattern
+	recorder := httptest.NewRecorder()
+	handler := beego.NewControllerRegister()
+	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
+		AllowOrigins: []string{"https://*.foo.com"},
+	}))
+	handler.Any("/foo", func(ctx *context.Context) {
+		ctx.Output.SetStatus(500)
+	})
+	origin := "https://ww.foo.com.evil.com" // suffix attack: would match without the trailing anchor
+	r, _ := http.NewRequest("PUT", "/foo", nil)
+	r.Header.Add("Origin", origin)
+	handler.ServeHTTP(recorder, r)
+
+	headersVal := recorder.HeaderMap.Get(headerAllowOrigin)
+	if headersVal != "" {
+		t.Errorf("Allow-Origin header should not exist, found %v", headersVal)
+	}
+}
+
+func Test_OtherHeaders(t *testing.T) { // every optional CORS header must be rendered from Options
+	recorder := httptest.NewRecorder()
+	handler := beego.NewControllerRegister()
+	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
+		AllowAllOrigins: true,
+		AllowCredentials: true,
+		AllowMethods: []string{"PATCH", "GET"},
+		AllowHeaders: []string{"Origin", "X-whatever"},
+		ExposeHeaders: []string{"Content-Length", "Hello"},
+		MaxAge: 5 * time.Minute,
+	}))
+	handler.Any("/foo", func(ctx *context.Context) {
+		ctx.Output.SetStatus(500)
+	})
+	r, _ := http.NewRequest("PUT", "/foo", nil)
+	handler.ServeHTTP(recorder, r)
+
+	credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials)
+	methodsVal := recorder.HeaderMap.Get(headerAllowMethods)
+	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
+	exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders)
+	maxAgeVal := recorder.HeaderMap.Get(headerMaxAge)
+
+	if credentialsVal != "true" {
+		t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal)
+	}
+
+	if methodsVal != "PATCH,GET" {
+		t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal)
+	}
+
+	if headersVal != "Origin,X-whatever" {
+		t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal)
+	}
+
+	if exposedHeadersVal != "Content-Length,Hello" {
+		t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal)
+	}
+
+	if maxAgeVal != "300" { // 5 minutes rendered as whole seconds
+		t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal)
+	}
+}
+
+func Test_DefaultAllowHeaders(t *testing.T) { // empty AllowHeaders must fall back to the package defaults
+	recorder := httptest.NewRecorder()
+	handler := beego.NewControllerRegister()
+	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
+		AllowAllOrigins: true,
+	}))
+	handler.Any("/foo", func(ctx *context.Context) {
+		ctx.Output.SetStatus(500)
+	})
+
+	r, _ := http.NewRequest("PUT", "/foo", nil)
+	handler.ServeHTTP(recorder, r)
+
+	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
+	if headersVal != "Origin,Accept,Content-Type,Authorization" { // matches defaultAllowHeaders joined with commas
+		t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal)
+	}
+}
+
+func Test_Preflight(t *testing.T) { // OPTIONS + Request-* headers must short-circuit with 200 and preflight headers
+	recorder := NewRecorder() // guard recorder: captures headers as of WriteHeader
+	handler := beego.NewControllerRegister()
+	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
+		AllowAllOrigins: true,
+		AllowMethods: []string{"PUT", "PATCH"},
+		AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"},
+	}))
+
+	handler.Any("/foo", func(ctx *context.Context) {
+		ctx.Output.SetStatus(200)
+	})
+
+	r, _ := http.NewRequest("OPTIONS", "/foo", nil)
+	r.Header.Add(headerRequestMethod, "PUT")
+	r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") // mixed case: matching must be case-insensitive
+	handler.ServeHTTP(recorder, r)
+
+	headers := recorder.Header()
+	methodsVal := headers.Get(headerAllowMethods)
+	headersVal := headers.Get(headerAllowHeaders)
+	originVal := headers.Get(headerAllowOrigin)
+
+	if methodsVal != "PUT,PATCH" {
+		t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal)
+	}
+
+	if !strings.Contains(headersVal, "X-whatever") {
+		t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal)
+	}
+
+	if !strings.Contains(headersVal, "x-casesensitive") { // requested spelling is echoed, not the configured one
+		t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal)
+	}
+
+	if originVal != "*" {
+		t.Errorf("Allow-Origin is expected to be *, found %v", originVal)
+	}
+
+	if recorder.Code != http.StatusOK {
+		t.Errorf("Status code is expected to be 200, found %d", recorder.Code)
+	}
+}
+
+func Benchmark_WithoutCORS(b *testing.B) { // baseline: routing cost with no CORS filter installed
+	recorder := httptest.NewRecorder()
+	handler := beego.NewControllerRegister()
+	beego.BConfig.RunMode = beego.PROD // PROD disables dev-mode logging overhead
+	handler.Any("/foo", func(ctx *context.Context) {
+		ctx.Output.SetStatus(500)
+	})
+	b.ResetTimer()
+	r, _ := http.NewRequest("PUT", "/foo", nil)
+	for i := 0; i < b.N; i++ {
+		handler.ServeHTTP(recorder, r)
+	}
+}
+
+func Benchmark_WithCORS(b *testing.B) { // same route with the full CORS filter, for comparison against the baseline
+	recorder := httptest.NewRecorder()
+	handler := beego.NewControllerRegister()
+	beego.BConfig.RunMode = beego.PROD
+	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
+		AllowAllOrigins: true,
+		AllowCredentials: true,
+		AllowMethods: []string{"PATCH", "GET"},
+		AllowHeaders: []string{"Origin", "X-whatever"},
+		MaxAge: 5 * time.Minute,
+	}))
+	handler.Any("/foo", func(ctx *context.Context) {
+		ctx.Output.SetStatus(500)
+	})
+	b.ResetTimer()
+	r, _ := http.NewRequest("PUT", "/foo", nil)
+	for i := 0; i < b.N; i++ {
+		handler.ServeHTTP(recorder, r)
+	}
+}
diff --git a/pkg/policy.go b/pkg/policy.go
new file mode 100644
index 00000000..ab23f927
--- /dev/null
+++ b/pkg/policy.go
@@ -0,0 +1,97 @@
+// Copyright 2016 beego authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "strings"
+
+ "github.com/astaxie/beego/context"
+)
+
+// PolicyFunc defines a policy function which is invoked before the controller handler is executed.
+type PolicyFunc func(*context.Context)
+
+// FindPolicy Find Router info for URL
+func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc {
+	var urlPath = cont.Input.URL()
+	if !BConfig.RouterCaseSensitive {
+		urlPath = strings.ToLower(urlPath) // addToPolicy lowercases patterns under the same setting
+	}
+	httpMethod := cont.Input.Method()
+	isWildcard := false
+	// Find policy for current method
+	t, ok := p.policies[httpMethod]
+	// If not found - find policy for whole controller
+	if !ok {
+		t, ok = p.policies["*"]
+		isWildcard = true // remember so we don't re-check "*" below
+	}
+	if ok {
+		runObjects := t.Match(urlPath, cont)
+		if r, ok := runObjects.([]PolicyFunc); ok {
+			return r
+		} else if !isWildcard {
+			// If no policies found and we checked not for "*" method - try to find it
+			t, ok = p.policies["*"]
+			if ok {
+				runObjects = t.Match(urlPath, cont)
+				if r, ok = runObjects.([]PolicyFunc); ok {
+					return r
+				}
+			}
+		}
+	}
+	return nil
+}
+
+func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc) { // registers policies on the per-method routing tree
+	method = strings.ToUpper(method)
+	p.enablePolicy = true // flips execPolicy from a no-op into active lookup
+	if !BConfig.RouterCaseSensitive {
+		pattern = strings.ToLower(pattern)
+	}
+	if t, ok := p.policies[method]; ok {
+		t.AddRouter(pattern, r)
+	} else {
+		t := NewTree() // first policy for this method: create its tree lazily
+		t.AddRouter(pattern, r)
+		p.policies[method] = t
+	}
+}
+
+// Policy Register new policy in beego
+func Policy(pattern, method string, policy ...PolicyFunc) {
+	BeeApp.Handlers.addToPolicy(method, pattern, policy...) // convenience wrapper over the global app's register
+}
+
+// Find policies and execute if were found
+func (p *ControllerRegister) execPolicy(cont *context.Context, urlPath string) (started bool) {
+	if !p.enablePolicy {
+		return false // fast path: no policy was ever registered
+	}
+	// Find Policy for method
+	policyList := p.FindPolicy(cont)
+	if len(policyList) > 0 {
+		// Run policies
+		for _, runPolicy := range policyList {
+			runPolicy(cont)
+			if cont.ResponseWriter.Started {
+				return true // a policy wrote the response; stop the chain and skip routing
+			}
+		}
+		return false
+	}
+	return false
+}
diff --git a/pkg/router.go b/pkg/router.go
new file mode 100644
index 00000000..6a8ac6f7
--- /dev/null
+++ b/pkg/router.go
@@ -0,0 +1,1052 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+ "reflect"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ beecontext "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/context/param"
+ "github.com/astaxie/beego/logs"
+ "github.com/astaxie/beego/toolbox"
+ "github.com/astaxie/beego/utils"
+)
+
+// default filter execution points
+const (
+ BeforeStatic = iota
+ BeforeRouter
+ BeforeExec
+ AfterExec
+ FinishRouter
+)
+
+const (
+ routerTypeBeego = iota
+ routerTypeRESTFul
+ routerTypeHandler
+)
+
+var (
+ // HTTPMETHOD list the supported http methods.
+ HTTPMETHOD = map[string]bool{
+ "GET": true,
+ "POST": true,
+ "PUT": true,
+ "DELETE": true,
+ "PATCH": true,
+ "OPTIONS": true,
+ "HEAD": true,
+ "TRACE": true,
+ "CONNECT": true,
+ "MKCOL": true,
+ "COPY": true,
+ "MOVE": true,
+ "PROPFIND": true,
+ "PROPPATCH": true,
+ "LOCK": true,
+ "UNLOCK": true,
+ }
+ // these beego.Controller's methods shouldn't reflect to AutoRouter
+ exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
+ "RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP",
+ "ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
+ "GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
+ "DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
+ "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
+ "GetControllerAndAction", "ServeFormatted"}
+
+ urlPlaceholder = "{{placeholder}}"
+ // DefaultAccessLogFilter will skip the accesslog if return true
+ DefaultAccessLogFilter FilterHandler = &logFilter{}
+)
+
+// FilterHandler is an interface for deciding whether a request should be filtered (e.g. skipped by the access log).
+type FilterHandler interface {
+ Filter(*beecontext.Context) bool
+}
+
+// default log filter static file will not show
+type logFilter struct {
+}
+
+func (l *logFilter) Filter(ctx *beecontext.Context) bool {
+ requestPath := path.Clean(ctx.Request.URL.Path)
+ if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
+ return true
+ }
+ for prefix := range BConfig.WebConfig.StaticDir {
+ if strings.HasPrefix(requestPath, prefix) {
+ return true
+ }
+ }
+ return false
+}
+
+// ExceptMethodAppend appends an action name to "exceptMethod", so that the controller method won't be mapped by AutoRouter.
+func ExceptMethodAppend(action string) {
+ exceptMethod = append(exceptMethod, action)
+}
+
+// ControllerInfo holds information about the controller.
+type ControllerInfo struct {
+ pattern string
+ controllerType reflect.Type
+ methods map[string]string
+ handler http.Handler
+ runFunction FilterFunc
+ routerType int
+ initialize func() ControllerInterface
+ methodParams []*param.MethodParam
+}
+
+func (c *ControllerInfo) GetPattern() string {
+ return c.pattern
+}
+
+// ControllerRegister contains registered router rules, controller handlers and filters.
+type ControllerRegister struct {
+ routers map[string]*Tree
+ enablePolicy bool
+ policies map[string]*Tree
+ enableFilter bool
+ filters [FinishRouter + 1][]*FilterRouter
+ pool sync.Pool
+}
+
+// NewControllerRegister returns a new ControllerRegister.
+func NewControllerRegister() *ControllerRegister {
+ return &ControllerRegister{
+ routers: make(map[string]*Tree),
+ policies: make(map[string]*Tree),
+ pool: sync.Pool{
+ New: func() interface{} {
+ return beecontext.NewContext()
+ },
+ },
+ }
+}
+
+// Add controller handler and pattern rules to ControllerRegister.
+// usage:
+// default methods is the same name as method
+// Add("/user",&UserController{})
+// Add("/api/list",&RestController{},"*:ListFood")
+// Add("/api/create",&RestController{},"post:CreateFood")
+// Add("/api/update",&RestController{},"put:UpdateFood")
+// Add("/api/delete",&RestController{},"delete:DeleteFood")
+//    Add("/api",&RestController{},"get,post:ApiFunc")
+// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
+func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
+ p.addWithMethodParams(pattern, c, nil, mappingMethods...)
+}
+
+func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) {
+ reflectVal := reflect.ValueOf(c)
+ t := reflect.Indirect(reflectVal).Type()
+ methods := make(map[string]string)
+ if len(mappingMethods) > 0 {
+ semi := strings.Split(mappingMethods[0], ";")
+ for _, v := range semi {
+ colon := strings.Split(v, ":")
+ if len(colon) != 2 {
+ panic("method mapping format is invalid")
+ }
+ comma := strings.Split(colon[0], ",")
+ for _, m := range comma {
+ if m == "*" || HTTPMETHOD[strings.ToUpper(m)] {
+ if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
+ methods[strings.ToUpper(m)] = colon[1]
+ } else {
+ panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name())
+ }
+ } else {
+ panic(v + " is an invalid method mapping. Method doesn't exist " + m)
+ }
+ }
+ }
+ }
+
+ route := &ControllerInfo{}
+ route.pattern = pattern
+ route.methods = methods
+ route.routerType = routerTypeBeego
+ route.controllerType = t
+ route.initialize = func() ControllerInterface {
+ vc := reflect.New(route.controllerType)
+ execController, ok := vc.Interface().(ControllerInterface)
+ if !ok {
+ panic("controller is not ControllerInterface")
+ }
+
+ elemVal := reflect.ValueOf(c).Elem()
+ elemType := reflect.TypeOf(c).Elem()
+ execElem := reflect.ValueOf(execController).Elem()
+
+ numOfFields := elemVal.NumField()
+ for i := 0; i < numOfFields; i++ {
+ fieldType := elemType.Field(i)
+ elemField := execElem.FieldByName(fieldType.Name)
+ if elemField.CanSet() {
+ fieldVal := elemVal.Field(i)
+ elemField.Set(fieldVal)
+ }
+ }
+
+ return execController
+ }
+
+ route.methodParams = methodParams
+ if len(methods) == 0 {
+ for m := range HTTPMETHOD {
+ p.addToRouter(m, pattern, route)
+ }
+ } else {
+ for k := range methods {
+ if k == "*" {
+ for m := range HTTPMETHOD {
+ p.addToRouter(m, pattern, route)
+ }
+ } else {
+ p.addToRouter(k, pattern, route)
+ }
+ }
+ }
+}
+
+func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) {
+ if !BConfig.RouterCaseSensitive {
+ pattern = strings.ToLower(pattern)
+ }
+ if t, ok := p.routers[method]; ok {
+ t.AddRouter(pattern, r)
+ } else {
+ t := NewTree()
+ t.AddRouter(pattern, r)
+ p.routers[method] = t
+ }
+}
+
+// Include registers the given controllers; only when RunMode is dev will it generate router info into router/auto.go from the controllers.
+// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
+func (p *ControllerRegister) Include(cList ...ControllerInterface) {
+ if BConfig.RunMode == DEV {
+ skip := make(map[string]bool, 10)
+ wgopath := utils.GetGOPATHs()
+ go111module := os.Getenv(`GO111MODULE`)
+ for _, c := range cList {
+ reflectVal := reflect.ValueOf(c)
+ t := reflect.Indirect(reflectVal).Type()
+ // for go modules
+ if go111module == `on` {
+ pkgpath := filepath.Join(WorkPath, "..", t.PkgPath())
+ if utils.FileExists(pkgpath) {
+ if pkgpath != "" {
+ if _, ok := skip[pkgpath]; !ok {
+ skip[pkgpath] = true
+ parserPkg(pkgpath, t.PkgPath())
+ }
+ }
+ }
+ } else {
+ if len(wgopath) == 0 {
+ panic("you are in dev mode. So please set gopath")
+ }
+ pkgpath := ""
+ for _, wg := range wgopath {
+ wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath()))
+ if utils.FileExists(wg) {
+ pkgpath = wg
+ break
+ }
+ }
+ if pkgpath != "" {
+ if _, ok := skip[pkgpath]; !ok {
+ skip[pkgpath] = true
+ parserPkg(pkgpath, t.PkgPath())
+ }
+ }
+ }
+ }
+ }
+ for _, c := range cList {
+ reflectVal := reflect.ValueOf(c)
+ t := reflect.Indirect(reflectVal).Type()
+ key := t.PkgPath() + ":" + t.Name()
+ if comm, ok := GlobalControllerRouter[key]; ok {
+ for _, a := range comm {
+ for _, f := range a.Filters {
+ p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams)
+ }
+
+ p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
+ }
+ }
+ }
+}
+
+// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context
+// And don't forget to give back context to pool
+// example:
+// ctx := p.GetContext()
+// ctx.Reset(w, q)
+// defer p.GiveBackContext(ctx)
+func (p *ControllerRegister) GetContext() *beecontext.Context {
+ return p.pool.Get().(*beecontext.Context)
+}
+
+// GiveBackContext puts the ctx back into the pool so that it can be reused.
+func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) {
+ p.pool.Put(ctx)
+}
+
+// Get add get method
+// usage:
+// Get("/", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func (p *ControllerRegister) Get(pattern string, f FilterFunc) {
+ p.AddMethod("get", pattern, f)
+}
+
+// Post add post method
+// usage:
+// Post("/api", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func (p *ControllerRegister) Post(pattern string, f FilterFunc) {
+ p.AddMethod("post", pattern, f)
+}
+
+// Put add put method
+// usage:
+// Put("/api/:id", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func (p *ControllerRegister) Put(pattern string, f FilterFunc) {
+ p.AddMethod("put", pattern, f)
+}
+
+// Delete add delete method
+// usage:
+// Delete("/api/:id", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func (p *ControllerRegister) Delete(pattern string, f FilterFunc) {
+ p.AddMethod("delete", pattern, f)
+}
+
+// Head add head method
+// usage:
+// Head("/api/:id", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func (p *ControllerRegister) Head(pattern string, f FilterFunc) {
+ p.AddMethod("head", pattern, f)
+}
+
+// Patch add patch method
+// usage:
+// Patch("/api/:id", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func (p *ControllerRegister) Patch(pattern string, f FilterFunc) {
+ p.AddMethod("patch", pattern, f)
+}
+
+// Options add options method
+// usage:
+// Options("/api/:id", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func (p *ControllerRegister) Options(pattern string, f FilterFunc) {
+ p.AddMethod("options", pattern, f)
+}
+
+// Any add all method
+// usage:
+// Any("/api/:id", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func (p *ControllerRegister) Any(pattern string, f FilterFunc) {
+ p.AddMethod("*", pattern, f)
+}
+
+// AddMethod add http method router
+// usage:
+// AddMethod("get","/api/:id", func(ctx *context.Context){
+// ctx.Output.Body("hello world")
+// })
+func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
+ method = strings.ToUpper(method)
+ if method != "*" && !HTTPMETHOD[method] {
+ panic("not support http method: " + method)
+ }
+ route := &ControllerInfo{}
+ route.pattern = pattern
+ route.routerType = routerTypeRESTFul
+ route.runFunction = f
+ methods := make(map[string]string)
+ if method == "*" {
+ for val := range HTTPMETHOD {
+ methods[val] = val
+ }
+ } else {
+ methods[method] = method
+ }
+ route.methods = methods
+ for k := range methods {
+ if k == "*" {
+ for m := range HTTPMETHOD {
+ p.addToRouter(m, pattern, route)
+ }
+ } else {
+ p.addToRouter(k, pattern, route)
+ }
+ }
+}
+
+// Handler add user defined Handler
+func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
+ route := &ControllerInfo{}
+ route.pattern = pattern
+ route.routerType = routerTypeHandler
+ route.handler = h
+ if len(options) > 0 {
+ if _, ok := options[0].(bool); ok {
+ pattern = path.Join(pattern, "?:all(.*)")
+ }
+ }
+ for m := range HTTPMETHOD {
+ p.addToRouter(m, pattern, route)
+ }
+}
+
+// AddAuto router to ControllerRegister.
+// example beego.AddAuto(&MainController{}),
+// MainController has method List and Page.
+// visit the url /main/list to execute List function
+// /main/page to execute Page function.
+func (p *ControllerRegister) AddAuto(c ControllerInterface) {
+ p.AddAutoPrefix("/", c)
+}
+
+// AddAutoPrefix Add auto router to ControllerRegister with prefix.
+// example beego.AddAutoPrefix("/admin",&MainController{}),
+// MainController has method List and Page.
+// visit the url /admin/main/list to execute List function
+// /admin/main/page to execute Page function.
+func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) {
+ reflectVal := reflect.ValueOf(c)
+ rt := reflectVal.Type()
+ ct := reflect.Indirect(reflectVal).Type()
+ controllerName := strings.TrimSuffix(ct.Name(), "Controller")
+ for i := 0; i < rt.NumMethod(); i++ {
+ if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
+ route := &ControllerInfo{}
+ route.routerType = routerTypeBeego
+ route.methods = map[string]string{"*": rt.Method(i).Name}
+ route.controllerType = ct
+ pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*")
+ patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*")
+ patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
+ patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
+ route.pattern = pattern
+ for m := range HTTPMETHOD {
+ p.addToRouter(m, pattern, route)
+ p.addToRouter(m, patternInit, route)
+ p.addToRouter(m, patternFix, route)
+ p.addToRouter(m, patternFixInit, route)
+ }
+ }
+ }
+}
+
+// InsertFilter Add a FilterFunc with pattern rule and action constant.
+// params is for:
+// 1. setting the returnOnOutput value (false allows multiple filters to execute)
+// 2. determining whether or not params need to be reset.
+func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
+ mr := &FilterRouter{
+ tree: NewTree(),
+ pattern: pattern,
+ filterFunc: filter,
+ returnOnOutput: true,
+ }
+ if !BConfig.RouterCaseSensitive {
+ mr.pattern = strings.ToLower(pattern)
+ }
+
+ paramsLen := len(params)
+ if paramsLen > 0 {
+ mr.returnOnOutput = params[0]
+ }
+ if paramsLen > 1 {
+ mr.resetParams = params[1]
+ }
+ mr.tree.AddRouter(pattern, true)
+ return p.insertFilterRouter(pos, mr)
+}
+
+// insertFilterRouter adds a FilterRouter into the filter chain at the given position.
+func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
+ if pos < BeforeStatic || pos > FinishRouter {
+ return errors.New("can not find your filter position")
+ }
+ p.enableFilter = true
+ p.filters[pos] = append(p.filters[pos], mr)
+ return nil
+}
+
+// URLFor builds the URL for a registered controller method from an endpoint like "path.controller.method",
+// it can access any controller method.
+func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string {
+ paths := strings.Split(endpoint, ".")
+ if len(paths) <= 1 {
+ logs.Warn("urlfor endpoint must like path.controller.method")
+ return ""
+ }
+ if len(values)%2 != 0 {
+ logs.Warn("urlfor params must key-value pair")
+ return ""
+ }
+ params := make(map[string]string)
+ if len(values) > 0 {
+ key := ""
+ for k, v := range values {
+ if k%2 == 0 {
+ key = fmt.Sprint(v)
+ } else {
+ params[key] = fmt.Sprint(v)
+ }
+ }
+ }
+ controllerName := strings.Join(paths[:len(paths)-1], "/")
+ methodName := paths[len(paths)-1]
+ for m, t := range p.routers {
+ ok, url := p.getURL(t, "/", controllerName, methodName, params, m)
+ if ok {
+ return url
+ }
+ }
+ return ""
+}
+
+func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) {
+ for _, subtree := range t.fixrouters {
+ u := path.Join(url, subtree.prefix)
+ ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod)
+ if ok {
+ return ok, u
+ }
+ }
+ if t.wildcard != nil {
+ u := path.Join(url, urlPlaceholder)
+ ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod)
+ if ok {
+ return ok, u
+ }
+ }
+ for _, l := range t.leaves {
+ if c, ok := l.runObject.(*ControllerInfo); ok {
+ if c.routerType == routerTypeBeego &&
+ strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) {
+ find := false
+ if HTTPMETHOD[strings.ToUpper(methodName)] {
+ if len(c.methods) == 0 {
+ find = true
+ } else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) {
+ find = true
+ } else if m, ok = c.methods["*"]; ok && m == methodName {
+ find = true
+ }
+ }
+ if !find {
+ for m, md := range c.methods {
+ if (m == "*" || m == httpMethod) && md == methodName {
+ find = true
+ }
+ }
+ }
+ if find {
+ if l.regexps == nil {
+ if len(l.wildcards) == 0 {
+ return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toURL(params)
+ }
+ if len(l.wildcards) == 1 {
+ if v, ok := params[l.wildcards[0]]; ok {
+ delete(params, l.wildcards[0])
+ return true, strings.Replace(url, urlPlaceholder, v, 1) + toURL(params)
+ }
+ return false, ""
+ }
+ if len(l.wildcards) == 3 && l.wildcards[0] == "." {
+ if p, ok := params[":path"]; ok {
+ if e, isok := params[":ext"]; isok {
+ delete(params, ":path")
+ delete(params, ":ext")
+ return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toURL(params)
+ }
+ }
+ }
+ canSkip := false
+ for _, v := range l.wildcards {
+ if v == ":" {
+ canSkip = true
+ continue
+ }
+ if u, ok := params[v]; ok {
+ delete(params, v)
+ url = strings.Replace(url, urlPlaceholder, u, 1)
+ } else {
+ if canSkip {
+ canSkip = false
+ continue
+ }
+ return false, ""
+ }
+ }
+ return true, url + toURL(params)
+ }
+ var i int
+ var startReg bool
+ regURL := ""
+ for _, v := range strings.Trim(l.regexps.String(), "^$") {
+ if v == '(' {
+ startReg = true
+ continue
+ } else if v == ')' {
+ startReg = false
+ if v, ok := params[l.wildcards[i]]; ok {
+ delete(params, l.wildcards[i])
+ regURL = regURL + v
+ i++
+ } else {
+ break
+ }
+ } else if !startReg {
+ regURL = string(append([]rune(regURL), v))
+ }
+ }
+ if l.regexps.MatchString(regURL) {
+ ps := strings.Split(regURL, "/")
+ for _, p := range ps {
+ url = strings.Replace(url, urlPlaceholder, p, 1)
+ }
+ return true, url + toURL(params)
+ }
+ }
+ }
+ }
+ }
+
+ return false, ""
+}
+
+func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) {
+ var preFilterParams map[string]string
+ for _, filterR := range p.filters[pos] {
+ if filterR.returnOnOutput && context.ResponseWriter.Started {
+ return true
+ }
+ if filterR.resetParams {
+ preFilterParams = context.Input.Params()
+ }
+ if ok := filterR.ValidRouter(urlPath, context); ok {
+ filterR.filterFunc(context)
+ if filterR.resetParams {
+ context.Input.ResetParams()
+ for k, v := range preFilterParams {
+ context.Input.SetParam(k, v)
+ }
+ }
+ }
+ if filterR.returnOnOutput && context.ResponseWriter.Started {
+ return true
+ }
+ }
+ return false
+}
+
+// Implement http.Handler interface.
+func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
+ startTime := time.Now()
+ var (
+ runRouter reflect.Type
+ findRouter bool
+ runMethod string
+ methodParams []*param.MethodParam
+ routerInfo *ControllerInfo
+ isRunnable bool
+ )
+ context := p.GetContext()
+
+ context.Reset(rw, r)
+
+ defer p.GiveBackContext(context)
+ if BConfig.RecoverFunc != nil {
+ defer BConfig.RecoverFunc(context)
+ }
+
+ context.Output.EnableGzip = BConfig.EnableGzip
+
+ if BConfig.RunMode == DEV {
+ context.Output.Header("Server", BConfig.ServerName)
+ }
+
+ var urlPath = r.URL.Path
+
+ if !BConfig.RouterCaseSensitive {
+ urlPath = strings.ToLower(urlPath)
+ }
+
+ // filter wrong http method
+ if !HTTPMETHOD[r.Method] {
+ exception("405", context)
+ goto Admin
+ }
+
+ // filter for static file
+ if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) {
+ goto Admin
+ }
+
+ serverStaticRouter(context)
+
+ if context.ResponseWriter.Started {
+ findRouter = true
+ goto Admin
+ }
+
+ if r.Method != http.MethodGet && r.Method != http.MethodHead {
+ if BConfig.CopyRequestBody && !context.Input.IsUpload() {
+ // connection will close if the incoming data are larger (RFC 7231, 6.5.11)
+ if r.ContentLength > BConfig.MaxMemory {
+ logs.Error(errors.New("payload too large"))
+ exception("413", context)
+ goto Admin
+ }
+ context.Input.CopyBody(BConfig.MaxMemory)
+ }
+ context.Input.ParseFormOrMulitForm(BConfig.MaxMemory)
+ }
+
+ // session init
+ if BConfig.WebConfig.Session.SessionOn {
+ var err error
+ context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
+ if err != nil {
+ logs.Error(err)
+ exception("503", context)
+ goto Admin
+ }
+ defer func() {
+ if context.Input.CruSession != nil {
+ context.Input.CruSession.SessionRelease(rw)
+ }
+ }()
+ }
+ if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) {
+ goto Admin
+ }
+ // User can define RunController and RunMethod in filter
+ if context.Input.RunController != nil && context.Input.RunMethod != "" {
+ findRouter = true
+ runMethod = context.Input.RunMethod
+ runRouter = context.Input.RunController
+ } else {
+ routerInfo, findRouter = p.FindRouter(context)
+ }
+
+ // if no matches to url, throw a not found exception
+ if !findRouter {
+ exception("404", context)
+ goto Admin
+ }
+ if splat := context.Input.Param(":splat"); splat != "" {
+ for k, v := range strings.Split(splat, "/") {
+ context.Input.SetParam(strconv.Itoa(k), v)
+ }
+ }
+
+ if routerInfo != nil {
+ // store router pattern into context
+ context.Input.SetData("RouterPattern", routerInfo.pattern)
+ }
+
+ // execute middleware filters
+ if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) {
+ goto Admin
+ }
+
+ // check policies
+ if p.execPolicy(context, urlPath) {
+ goto Admin
+ }
+
+ if routerInfo != nil {
+ if routerInfo.routerType == routerTypeRESTFul {
+ if _, ok := routerInfo.methods[r.Method]; ok {
+ isRunnable = true
+ routerInfo.runFunction(context)
+ } else {
+ exception("405", context)
+ goto Admin
+ }
+ } else if routerInfo.routerType == routerTypeHandler {
+ isRunnable = true
+ routerInfo.handler.ServeHTTP(context.ResponseWriter, context.Request)
+ } else {
+ runRouter = routerInfo.controllerType
+ methodParams = routerInfo.methodParams
+ method := r.Method
+ if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut {
+ method = http.MethodPut
+ }
+ if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete {
+ method = http.MethodDelete
+ }
+ if m, ok := routerInfo.methods[method]; ok {
+ runMethod = m
+ } else if m, ok = routerInfo.methods["*"]; ok {
+ runMethod = m
+ } else {
+ runMethod = method
+ }
+ }
+ }
+
+ // also defined runRouter & runMethod from filter
+ if !isRunnable {
+ // Invoke the request handler
+ var execController ControllerInterface
+ if routerInfo != nil && routerInfo.initialize != nil {
+ execController = routerInfo.initialize()
+ } else {
+ vc := reflect.New(runRouter)
+ var ok bool
+ execController, ok = vc.Interface().(ControllerInterface)
+ if !ok {
+ panic("controller is not ControllerInterface")
+ }
+ }
+
+ // call the controller init function
+ execController.Init(context, runRouter.Name(), runMethod, execController)
+
+ // call prepare function
+ execController.Prepare()
+
+ // if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
+ if BConfig.WebConfig.EnableXSRF {
+ execController.XSRFToken()
+ if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut ||
+ (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) {
+ execController.CheckXSRFCookie()
+ }
+ }
+
+ execController.URLMapping()
+
+ if !context.ResponseWriter.Started {
+ // exec main logic
+ switch runMethod {
+ case http.MethodGet:
+ execController.Get()
+ case http.MethodPost:
+ execController.Post()
+ case http.MethodDelete:
+ execController.Delete()
+ case http.MethodPut:
+ execController.Put()
+ case http.MethodHead:
+ execController.Head()
+ case http.MethodPatch:
+ execController.Patch()
+ case http.MethodOptions:
+ execController.Options()
+ case http.MethodTrace:
+ execController.Trace()
+ default:
+ if !execController.HandlerFunc(runMethod) {
+ vc := reflect.ValueOf(execController)
+ method := vc.MethodByName(runMethod)
+ in := param.ConvertParams(methodParams, method.Type(), context)
+ out := method.Call(in)
+
+ // For backward compatibility we only handle response if we had incoming methodParams
+ if methodParams != nil {
+ p.handleParamResponse(context, execController, out)
+ }
+ }
+ }
+
+ // render template
+ if !context.ResponseWriter.Started && context.Output.Status == 0 {
+ if BConfig.WebConfig.AutoRender {
+ if err := execController.Render(); err != nil {
+ logs.Error(err)
+ }
+ }
+ }
+ }
+
+ // finish all runRouter. release resource
+ execController.Finish()
+ }
+
+ // execute middleware filters
+ if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) {
+ goto Admin
+ }
+
+ if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) {
+ goto Admin
+ }
+
+Admin:
+ // admin module record QPS
+
+ statusCode := context.ResponseWriter.Status
+ if statusCode == 0 {
+ statusCode = 200
+ }
+
+ LogAccess(context, &startTime, statusCode)
+
+ timeDur := time.Since(startTime)
+ context.ResponseWriter.Elapsed = timeDur
+ if BConfig.Listen.EnableAdmin {
+ pattern := ""
+ if routerInfo != nil {
+ pattern = routerInfo.pattern
+ }
+
+ if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) {
+ routerName := ""
+ if runRouter != nil {
+ routerName = runRouter.Name()
+ }
+ go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur)
+ }
+ }
+
+ if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs {
+ match := map[bool]string{true: "match", false: "nomatch"}
+ devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s",
+ context.Input.IP(),
+ logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(),
+ timeDur.String(),
+ match[findRouter],
+ logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(),
+ r.URL.Path)
+ if routerInfo != nil {
+ devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern)
+ }
+
+ logs.Debug(devInfo)
+ }
+ // Call WriteHeader if status code has been set changed
+ if context.Output.Status != 0 {
+ context.ResponseWriter.WriteHeader(context.Output.Status)
+ }
+}
+
+func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
+ // looping in reverse order for the case when both error and value are returned and error sets the response status code
+ for i := len(results) - 1; i >= 0; i-- {
+ result := results[i]
+ if result.Kind() != reflect.Interface || !result.IsNil() {
+ resultValue := result.Interface()
+ context.RenderMethodResult(resultValue)
+ }
+ }
+ if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 {
+ context.Output.SetStatus(200)
+ }
+}
+
+// FindRouter returns the router info matching the request's URL and method, and whether a match was found.
+func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) {
+ var urlPath = context.Input.URL()
+ if !BConfig.RouterCaseSensitive {
+ urlPath = strings.ToLower(urlPath)
+ }
+ httpMethod := context.Input.Method()
+ if t, ok := p.routers[httpMethod]; ok {
+ runObject := t.Match(urlPath, context)
+ if r, ok := runObject.(*ControllerInfo); ok {
+ return r, true
+ }
+ }
+ return
+}
+
+func toURL(params map[string]string) string {
+ if len(params) == 0 {
+ return ""
+ }
+ u := "?"
+ for k, v := range params {
+ u += k + "=" + v + "&"
+ }
+ return strings.TrimRight(u, "&")
+}
+
+// LogAccess logging info HTTP Access
+func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) {
+ // Skip logging if AccessLogs config is false
+ if !BConfig.Log.AccessLogs {
+ return
+ }
+ // Skip logging static requests unless EnableStaticLogs config is true
+ if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) {
+ return
+ }
+ var (
+ requestTime time.Time
+ elapsedTime time.Duration
+ r = ctx.Request
+ )
+ if startTime != nil {
+ requestTime = *startTime
+ elapsedTime = time.Since(*startTime)
+ }
+ record := &logs.AccessLogRecord{
+ RemoteAddr: ctx.Input.IP(),
+ RequestTime: requestTime,
+ RequestMethod: r.Method,
+ Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto),
+ ServerProtocol: r.Proto,
+ Host: r.Host,
+ Status: statusCode,
+ ElapsedTime: elapsedTime,
+ HTTPReferrer: r.Header.Get("Referer"),
+ HTTPUserAgent: r.Header.Get("User-Agent"),
+ RemoteUser: r.Header.Get("Remote-User"),
+ BodyBytesSent: r.ContentLength,
+ }
+ logs.AccessLog(record, BConfig.Log.AccessLogsFormat)
+}
diff --git a/pkg/router_test.go b/pkg/router_test.go
new file mode 100644
index 00000000..8ec7927a
--- /dev/null
+++ b/pkg/router_test.go
@@ -0,0 +1,732 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "bytes"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
+)
+
+type TestController struct {
+ Controller
+}
+
+func (tc *TestController) Get() {
+ tc.Data["Username"] = "astaxie"
+ tc.Ctx.Output.Body([]byte("ok"))
+}
+
+func (tc *TestController) Post() {
+ tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name")))
+}
+
+func (tc *TestController) Param() {
+ tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name")))
+}
+
+func (tc *TestController) List() {
+ tc.Ctx.Output.Body([]byte("i am list"))
+}
+
+func (tc *TestController) Params() {
+ tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param("0") + tc.Ctx.Input.Param("1") + tc.Ctx.Input.Param("2")))
+}
+
+func (tc *TestController) Myext() {
+ tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param(":ext")))
+}
+
+func (tc *TestController) GetURL() {
+ tc.Ctx.Output.Body([]byte(tc.URLFor(".Myext")))
+}
+
+func (tc *TestController) GetParams() {
+ tc.Ctx.WriteString(tc.Ctx.Input.Query(":last") + "+" +
+ tc.Ctx.Input.Query(":first") + "+" + tc.Ctx.Input.Query("learn"))
+}
+
+func (tc *TestController) GetManyRouter() {
+ tc.Ctx.WriteString(tc.Ctx.Input.Query(":id") + tc.Ctx.Input.Query(":page"))
+}
+
+func (tc *TestController) GetEmptyBody() {
+ var res []byte
+ tc.Ctx.Output.Body(res)
+}
+
+type JSONController struct {
+ Controller
+}
+
+func (jc *JSONController) Prepare() {
+ jc.Data["json"] = "prepare"
+ jc.ServeJSON(true)
+}
+
+func (jc *JSONController) Get() {
+ jc.Data["Username"] = "astaxie"
+ jc.Ctx.Output.Body([]byte("ok"))
+}
+
+func TestUrlFor(t *testing.T) {
+ handler := NewControllerRegister()
+ handler.Add("/api/list", &TestController{}, "*:List")
+ handler.Add("/person/:last/:first", &TestController{}, "*:Param")
+ if a := handler.URLFor("TestController.List"); a != "/api/list" {
+ logs.Info(a)
+ t.Errorf("TestController.List must equal to /api/list")
+ }
+ if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" {
+ t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + a)
+ }
+}
+
+func TestUrlFor3(t *testing.T) {
+ handler := NewControllerRegister()
+ handler.AddAuto(&TestController{})
+ if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" {
+ t.Errorf("TestController.Myext must equal to /test/myext, but get " + a)
+ }
+ if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" {
+ t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a)
+ }
+}
+
+func TestUrlFor2(t *testing.T) {
+ handler := NewControllerRegister()
+ handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List")
+ handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL")
+ handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param")
+ handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
+ if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
+ logs.Info(handler.URLFor("TestController.GetURL"))
+ t.Errorf("TestController.List must equal to /v1/astaxie/edit")
+ }
+
+ if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") !=
+ "/v1/za/cms_12_123.html" {
+ logs.Info(handler.URLFor("TestController.List"))
+ t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html")
+ }
+ if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") !=
+ "/v1/za_cms/ttt_12_123.html" {
+ logs.Info(handler.URLFor("TestController.Param"))
+ t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html")
+ }
+ if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11",
+ ":title", "aaaa", ":entid", "aaaa") !=
+ "/1111/11/aaaa/aaaa" {
+ logs.Info(handler.URLFor("TestController.Get"))
+ t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa")
+ }
+}
+
+func TestUserFunc(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/api/list", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/api/list", &TestController{}, "*:List")
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "i am list" {
+ t.Errorf("user define func can't run")
+ }
+}
+
+func TestPostFunc(t *testing.T) {
+ r, _ := http.NewRequest("POST", "/astaxie", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/:name", &TestController{})
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "astaxie" {
+ t.Errorf("post func should astaxie")
+ }
+}
+
+func TestAutoFunc(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/test/list", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.AddAuto(&TestController{})
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "i am list" {
+ t.Errorf("user define func can't run")
+ }
+}
+
+func TestAutoFunc2(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/Test/List", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.AddAuto(&TestController{})
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "i am list" {
+ t.Errorf("user define func can't run")
+ }
+}
+
+func TestAutoFuncParams(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/test/params/2009/11/12", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.AddAuto(&TestController{})
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "20091112" {
+ t.Errorf("user define func can't run")
+ }
+}
+
+func TestAutoExtFunc(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/test/myext.json", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.AddAuto(&TestController{})
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "json" {
+ t.Errorf("user define func can't run")
+ }
+}
+
+func TestRouteOk(t *testing.T) {
+
+ r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/person/:last/:first", &TestController{}, "get:GetParams")
+ handler.ServeHTTP(w, r)
+ body := w.Body.String()
+ if body != "anderson+thomas+kungfu" {
+ t.Errorf("url param set to [%s];", body)
+ }
+}
+
+func TestManyRoute(t *testing.T) {
+
+ r, _ := http.NewRequest("GET", "/beego32-12.html", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter")
+ handler.ServeHTTP(w, r)
+
+ body := w.Body.String()
+
+ if body != "3212" {
+ t.Errorf("url param set to [%s];", body)
+ }
+}
+
+// Test for issue #1669
+func TestEmptyResponse(t *testing.T) {
+
+ r, _ := http.NewRequest("GET", "/beego-empty.html", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody")
+ handler.ServeHTTP(w, r)
+
+ if body := w.Body.String(); body != "" {
+ t.Error("want empty body")
+ }
+}
+
+func TestNotFound(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.ServeHTTP(w, r)
+
+ if w.Code != http.StatusNotFound {
+ t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusNotFound)
+ }
+}
+
+// TestStatic tests the ability to serve static
+// content from the filesystem
+func TestStatic(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/static/js/jquery.js", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.ServeHTTP(w, r)
+
+ if w.Code != 404 {
+ t.Errorf("handler.Static failed to serve file")
+ }
+}
+
+func TestPrepare(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/json/list", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/json/list", &JSONController{})
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != `"prepare"` {
+ t.Errorf(w.Body.String() + "user define func can't run")
+ }
+}
+
+func TestAutoPrefix(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/admin/test/list", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.AddAutoPrefix("/admin", &TestController{})
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "i am list" {
+ t.Errorf("TestAutoPrefix can't run")
+ }
+}
+
+func TestRouterGet(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/user", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Get("/user", func(ctx *context.Context) {
+ ctx.Output.Body([]byte("Get userlist"))
+ })
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "Get userlist" {
+ t.Errorf("TestRouterGet can't run")
+ }
+}
+
+func TestRouterPost(t *testing.T) {
+ r, _ := http.NewRequest("POST", "/user/123", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Post("/user/:id", func(ctx *context.Context) {
+ ctx.Output.Body([]byte(ctx.Input.Param(":id")))
+ })
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "123" {
+ t.Errorf("TestRouterPost can't run")
+ }
+}
+
+func sayhello(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("sayhello"))
+}
+
+func TestRouterHandler(t *testing.T) {
+ r, _ := http.NewRequest("POST", "/sayhi", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Handler("/sayhi", http.HandlerFunc(sayhello))
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "sayhello" {
+ t.Errorf("TestRouterHandler can't run")
+ }
+}
+
+func TestRouterHandlerAll(t *testing.T) {
+ r, _ := http.NewRequest("POST", "/sayhi/a/b/c", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Handler("/sayhi", http.HandlerFunc(sayhello), true)
+ handler.ServeHTTP(w, r)
+ if w.Body.String() != "sayhello" {
+ t.Errorf("TestRouterHandler can't run")
+ }
+}
+
+//
+// Benchmarks NewApp:
+//
+
+func beegoFilterFunc(ctx *context.Context) {
+ ctx.WriteString("hello")
+}
+
+type AdminController struct {
+ Controller
+}
+
+func (a *AdminController) Get() {
+ a.Ctx.WriteString("hello")
+}
+
+func TestRouterFunc(t *testing.T) {
+ mux := NewControllerRegister()
+ mux.Get("/action", beegoFilterFunc)
+ mux.Post("/action", beegoFilterFunc)
+ rw, r := testRequest("GET", "/action")
+ mux.ServeHTTP(rw, r)
+ if rw.Body.String() != "hello" {
+ t.Errorf("TestRouterFunc can't run")
+ }
+}
+
+func BenchmarkFunc(b *testing.B) {
+ mux := NewControllerRegister()
+ mux.Get("/action", beegoFilterFunc)
+ rw, r := testRequest("GET", "/action")
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mux.ServeHTTP(rw, r)
+ }
+}
+
+func BenchmarkController(b *testing.B) {
+ mux := NewControllerRegister()
+ mux.Add("/action", &AdminController{})
+ rw, r := testRequest("GET", "/action")
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mux.ServeHTTP(rw, r)
+ }
+}
+
+func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request) {
+ request, _ := http.NewRequest(method, path, nil)
+ recorder := httptest.NewRecorder()
+
+ return recorder, request
+}
+
+// Expectation: A Filter with the correct configuration should be created given
+// specific parameters.
+func TestInsertFilter(t *testing.T) {
+ testName := "TestInsertFilter"
+
+ mux := NewControllerRegister()
+ mux.InsertFilter("*", BeforeRouter, func(*context.Context) {})
+ if !mux.filters[BeforeRouter][0].returnOnOutput {
+ t.Errorf(
+ "%s: passing no variadic params should set returnOnOutput to true",
+ testName)
+ }
+ if mux.filters[BeforeRouter][0].resetParams {
+ t.Errorf(
+ "%s: passing no variadic params should set resetParams to false",
+ testName)
+ }
+
+ mux = NewControllerRegister()
+ mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false)
+ if mux.filters[BeforeRouter][0].returnOnOutput {
+ t.Errorf(
+ "%s: passing false as 1st variadic param should set returnOnOutput to false",
+ testName)
+ }
+
+ mux = NewControllerRegister()
+ mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true)
+ if !mux.filters[BeforeRouter][0].resetParams {
+ t.Errorf(
+ "%s: passing true as 2nd variadic param should set resetParams to true",
+ testName)
+ }
+}
+
+// Expectation: the second variadic arg should cause the execution of the filter
+// to preserve the parameters from before its execution.
+func TestParamResetFilter(t *testing.T) {
+ testName := "TestParamResetFilter"
+ route := "/beego/*" // splat
+ path := "/beego/routes/routes"
+
+ mux := NewControllerRegister()
+
+ mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true)
+
+ mux.Get(route, beegoHandleResetParams)
+
+ rw, r := testRequest("GET", path)
+ mux.ServeHTTP(rw, r)
+
+ // The two functions, `beegoResetParams` and `beegoHandleResetParams` add
+ // a response header of `Splat`. The expectation here is that that Header
+ // value should match what the _request's_ router set, not the filter's.
+
+ headers := rw.Result().Header
+ if len(headers["Splat"]) != 1 {
+ t.Errorf(
+ "%s: There was an error in the test. Splat param not set in Header",
+ testName)
+ }
+ if headers["Splat"][0] != "routes/routes" {
+ t.Errorf(
+ "%s: expected `:splat` param to be [routes/routes] but it was [%s]",
+ testName, headers["Splat"][0])
+ }
+}
+
+// Execution point: BeforeRouter
+// expectation: only the BeforeRouter filter function is executed; there is no matched-route output because the router never handles the request
+func TestFilterBeforeRouter(t *testing.T) {
+ testName := "TestFilterBeforeRouter"
+ url := "/beforeRouter"
+
+ mux := NewControllerRegister()
+ mux.InsertFilter(url, BeforeRouter, beegoBeforeRouter1)
+
+ mux.Get(url, beegoFilterFunc)
+
+ rw, r := testRequest("GET", url)
+ mux.ServeHTTP(rw, r)
+
+ if !strings.Contains(rw.Body.String(), "BeforeRouter1") {
+ t.Errorf(testName + " BeforeRouter did not run")
+ }
+ if strings.Contains(rw.Body.String(), "hello") {
+ t.Errorf(testName + " BeforeRouter did not return properly")
+ }
+}
+
+// Execution point: BeforeExec
+// expectation: only BeforeExec function is executed, match as router determines route only
+func TestFilterBeforeExec(t *testing.T) {
+ testName := "TestFilterBeforeExec"
+ url := "/beforeExec"
+
+ mux := NewControllerRegister()
+ mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput)
+ mux.InsertFilter(url, BeforeExec, beegoBeforeExec1)
+
+ mux.Get(url, beegoFilterFunc)
+
+ rw, r := testRequest("GET", url)
+ mux.ServeHTTP(rw, r)
+
+ if !strings.Contains(rw.Body.String(), "BeforeExec1") {
+ t.Errorf(testName + " BeforeExec did not run")
+ }
+ if strings.Contains(rw.Body.String(), "hello") {
+ t.Errorf(testName + " BeforeExec did not return properly")
+ }
+ if strings.Contains(rw.Body.String(), "BeforeRouter") {
+ t.Errorf(testName + " BeforeRouter ran in error")
+ }
+}
+
+// Execution point: AfterExec
+// expectation: only AfterExec function is executed, match as router handles
+func TestFilterAfterExec(t *testing.T) {
+ testName := "TestFilterAfterExec"
+ url := "/afterExec"
+
+ mux := NewControllerRegister()
+ mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput)
+ mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput)
+ mux.InsertFilter(url, AfterExec, beegoAfterExec1, false)
+
+ mux.Get(url, beegoFilterFunc)
+
+ rw, r := testRequest("GET", url)
+ mux.ServeHTTP(rw, r)
+
+ if !strings.Contains(rw.Body.String(), "AfterExec1") {
+ t.Errorf(testName + " AfterExec did not run")
+ }
+ if !strings.Contains(rw.Body.String(), "hello") {
+ t.Errorf(testName + " handler did not run properly")
+ }
+ if strings.Contains(rw.Body.String(), "BeforeRouter") {
+ t.Errorf(testName + " BeforeRouter ran in error")
+ }
+ if strings.Contains(rw.Body.String(), "BeforeExec") {
+ t.Errorf(testName + " BeforeExec ran in error")
+ }
+}
+
+// Execution point: FinishRouter
+// expectation: only FinishRouter function is executed, match as router handles
+func TestFilterFinishRouter(t *testing.T) {
+ testName := "TestFilterFinishRouter"
+ url := "/finishRouter"
+
+ mux := NewControllerRegister()
+ mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput)
+ mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput)
+ mux.InsertFilter(url, AfterExec, beegoFilterNoOutput)
+ mux.InsertFilter(url, FinishRouter, beegoFinishRouter1)
+
+ mux.Get(url, beegoFilterFunc)
+
+ rw, r := testRequest("GET", url)
+ mux.ServeHTTP(rw, r)
+
+ if strings.Contains(rw.Body.String(), "FinishRouter1") {
+ t.Errorf(testName + " FinishRouter did not run")
+ }
+ if !strings.Contains(rw.Body.String(), "hello") {
+ t.Errorf(testName + " handler did not run properly")
+ }
+ if strings.Contains(rw.Body.String(), "AfterExec1") {
+ t.Errorf(testName + " AfterExec ran in error")
+ }
+ if strings.Contains(rw.Body.String(), "BeforeRouter") {
+ t.Errorf(testName + " BeforeRouter ran in error")
+ }
+ if strings.Contains(rw.Body.String(), "BeforeExec") {
+ t.Errorf(testName + " BeforeExec ran in error")
+ }
+}
+
+// Execution point: FinishRouter
+// expectation: only first FinishRouter function is executed, match as router handles
+func TestFilterFinishRouterMultiFirstOnly(t *testing.T) {
+ testName := "TestFilterFinishRouterMultiFirstOnly"
+ url := "/finishRouterMultiFirstOnly"
+
+ mux := NewControllerRegister()
+ mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false)
+ mux.InsertFilter(url, FinishRouter, beegoFinishRouter2)
+
+ mux.Get(url, beegoFilterFunc)
+
+ rw, r := testRequest("GET", url)
+ mux.ServeHTTP(rw, r)
+
+ if !strings.Contains(rw.Body.String(), "FinishRouter1") {
+ t.Errorf(testName + " FinishRouter1 did not run")
+ }
+ if !strings.Contains(rw.Body.String(), "hello") {
+ t.Errorf(testName + " handler did not run properly")
+ }
+ // not expected in body
+ if strings.Contains(rw.Body.String(), "FinishRouter2") {
+ t.Errorf(testName + " FinishRouter2 did run")
+ }
+}
+
+// Execution point: FinishRouter
+// expectation: both FinishRouter functions execute, match as router handles
+func TestFilterFinishRouterMulti(t *testing.T) {
+ testName := "TestFilterFinishRouterMulti"
+ url := "/finishRouterMulti"
+
+ mux := NewControllerRegister()
+ mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false)
+ mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false)
+
+ mux.Get(url, beegoFilterFunc)
+
+ rw, r := testRequest("GET", url)
+ mux.ServeHTTP(rw, r)
+
+ if !strings.Contains(rw.Body.String(), "FinishRouter1") {
+ t.Errorf(testName + " FinishRouter1 did not run")
+ }
+ if !strings.Contains(rw.Body.String(), "hello") {
+ t.Errorf(testName + " handler did not run properly")
+ }
+ if !strings.Contains(rw.Body.String(), "FinishRouter2") {
+ t.Errorf(testName + " FinishRouter2 did not run properly")
+ }
+}
+
+func beegoFilterNoOutput(ctx *context.Context) {
+}
+
+func beegoBeforeRouter1(ctx *context.Context) {
+ ctx.WriteString("|BeforeRouter1")
+}
+
+func beegoBeforeExec1(ctx *context.Context) {
+ ctx.WriteString("|BeforeExec1")
+}
+
+func beegoAfterExec1(ctx *context.Context) {
+ ctx.WriteString("|AfterExec1")
+}
+
+func beegoFinishRouter1(ctx *context.Context) {
+ ctx.WriteString("|FinishRouter1")
+}
+
+func beegoFinishRouter2(ctx *context.Context) {
+ ctx.WriteString("|FinishRouter2")
+}
+
+func beegoResetParams(ctx *context.Context) {
+ ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat"))
+}
+
+func beegoHandleResetParams(ctx *context.Context) {
+ ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat"))
+}
+
+// YAML
+type YAMLController struct {
+ Controller
+}
+
+func (jc *YAMLController) Prepare() {
+ jc.Data["yaml"] = "prepare"
+ jc.ServeYAML()
+}
+
+func (jc *YAMLController) Get() {
+ jc.Data["Username"] = "astaxie"
+ jc.Ctx.Output.Body([]byte("ok"))
+}
+
+func TestYAMLPrepare(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/yaml/list", nil)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Add("/yaml/list", &YAMLController{})
+ handler.ServeHTTP(w, r)
+ if strings.TrimSpace(w.Body.String()) != "prepare" {
+ t.Errorf(w.Body.String())
+ }
+}
+
+func TestRouterEntityTooLargeCopyBody(t *testing.T) {
+ _MaxMemory := BConfig.MaxMemory
+ _CopyRequestBody := BConfig.CopyRequestBody
+ BConfig.CopyRequestBody = true
+ BConfig.MaxMemory = 20
+
+ b := bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar"))
+ r, _ := http.NewRequest("POST", "/user/123", b)
+ w := httptest.NewRecorder()
+
+ handler := NewControllerRegister()
+ handler.Post("/user/:id", func(ctx *context.Context) {
+ ctx.Output.Body([]byte(ctx.Input.Param(":id")))
+ })
+ handler.ServeHTTP(w, r)
+
+ BConfig.CopyRequestBody = _CopyRequestBody
+ BConfig.MaxMemory = _MaxMemory
+
+ if w.Code != http.StatusRequestEntityTooLarge {
+ t.Errorf("TestRouterRequestEntityTooLarge can't run")
+ }
+}
diff --git a/pkg/session/README.md b/pkg/session/README.md
new file mode 100644
index 00000000..6d0a297e
--- /dev/null
+++ b/pkg/session/README.md
@@ -0,0 +1,114 @@
+session
+==============
+
+session is a Go session manager. It can use many session providers. Just like the `database/sql` and `database/sql/driver`.
+
+## How to install?
+
+ go get github.com/astaxie/beego/session
+
+
+## What providers are supported?
+
+As of now this session manager supports memory, file, Redis and MySQL.
+
+
+## How to use it?
+
+First you must import it
+
+ import (
+ "github.com/astaxie/beego/session"
+ )
+
+Then in your web app, initialize the global session manager
+
+ var globalSessions *session.Manager
+
+* Use **memory** as provider:
+
+ func init() {
+ globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
+ go globalSessions.GC()
+ }
+
+* Use **file** as provider, the last param is the path where you want file to be stored:
+
+ func init() {
+ globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`)
+ go globalSessions.GC()
+ }
+
+* Use **Redis** as provider, the last param is the Redis connection address, pool size and password:
+
+ func init() {
+ globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`)
+ go globalSessions.GC()
+ }
+
+* Use **MySQL** as provider, the last param is the DSN, learn more from [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name):
+
+ func init() {
+ globalSessions, _ = session.NewManager(
+ "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`)
+ go globalSessions.GC()
+ }
+
+* Use **Cookie** as provider:
+
+ func init() {
+ globalSessions, _ = session.NewManager(
+ "cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
+ go globalSessions.GC()
+ }
+
+
+Finally in the handlerfunc you can use it like this
+
+ func login(w http.ResponseWriter, r *http.Request) {
+ sess := globalSessions.SessionStart(w, r)
+ defer sess.SessionRelease(w)
+ username := sess.Get("username")
+ fmt.Println(username)
+ if r.Method == "GET" {
+ t, _ := template.ParseFiles("login.gtpl")
+ t.Execute(w, nil)
+ } else {
+ fmt.Println("username:", r.Form["username"])
+ sess.Set("username", r.Form["username"])
+ fmt.Println("password:", r.Form["password"])
+ }
+ }
+
+
+## How to write own provider?
+
+When you develop a web app, you may want to write your own provider to meet your specific requirements.
+
+Writing a provider is easy. You only need to define two struct types
+(Session and Provider), which satisfy the interface definition.
+Maybe you will find the **memory** provider is a good example.
+
+ type SessionStore interface {
+ Set(key, value interface{}) error //set session value
+ Get(key interface{}) interface{} //get session value
+ Delete(key interface{}) error //delete session value
+ SessionID() string //back current sessionID
+ SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
+ Flush() error //delete all data
+ }
+
+ type Provider interface {
+ SessionInit(gclifetime int64, config string) error
+ SessionRead(sid string) (SessionStore, error)
+ SessionExist(sid string) bool
+ SessionRegenerate(oldsid, sid string) (SessionStore, error)
+ SessionDestroy(sid string) error
+ SessionAll() int //get all active session
+ SessionGC()
+ }
+
+
+## LICENSE
+
+BSD License http://creativecommons.org/licenses/BSD/
diff --git a/pkg/session/couchbase/sess_couchbase.go b/pkg/session/couchbase/sess_couchbase.go
new file mode 100644
index 00000000..707d042c
--- /dev/null
+++ b/pkg/session/couchbase/sess_couchbase.go
@@ -0,0 +1,247 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package couchbase for session provider
+//
+// depends on github.com/couchbaselabs/go-couchbase
+//
+// go install github.com/couchbaselabs/go-couchbase
+//
+// Usage:
+// import(
+// _ "github.com/astaxie/beego/session/couchbase"
+// "github.com/astaxie/beego/session"
+// )
+//
+// func init() {
+// globalSessions, _ = session.NewManager("couchbase", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}`)
+// go globalSessions.GC()
+// }
+//
+// more docs: http://beego.me/docs/module/session.md
+package couchbase
+
+import (
+ "net/http"
+ "strings"
+ "sync"
+
+ couchbase "github.com/couchbase/go-couchbase"
+
+ "github.com/astaxie/beego/session"
+)
+
+var couchbpder = &Provider{}
+
+// SessionStore store each session
+type SessionStore struct {
+ b *couchbase.Bucket
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+ maxlifetime int64
+}
+
+// Provider couchbase session provider
+type Provider struct {
+ maxlifetime int64
+ savePath string
+ pool string
+ bucket string
+ b *couchbase.Bucket
+}
+
+// Set value to couchbase session
+func (cs *SessionStore) Set(key, value interface{}) error {
+ cs.lock.Lock()
+ defer cs.lock.Unlock()
+ cs.values[key] = value
+ return nil
+}
+
+// Get value from couchbase session
+func (cs *SessionStore) Get(key interface{}) interface{} {
+ cs.lock.RLock()
+ defer cs.lock.RUnlock()
+ if v, ok := cs.values[key]; ok {
+ return v
+ }
+ return nil
+}
+
+// Delete value in couchbase session by given key
+func (cs *SessionStore) Delete(key interface{}) error {
+ cs.lock.Lock()
+ defer cs.lock.Unlock()
+ delete(cs.values, key)
+ return nil
+}
+
+// Flush Clean all values in couchbase session
+func (cs *SessionStore) Flush() error {
+ cs.lock.Lock()
+ defer cs.lock.Unlock()
+ cs.values = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID Get couchbase session store id
+func (cs *SessionStore) SessionID() string {
+ return cs.sid
+}
+
+// SessionRelease Write couchbase session with Gob string
+func (cs *SessionStore) SessionRelease(w http.ResponseWriter) {
+ defer cs.b.Close()
+
+ bo, err := session.EncodeGob(cs.values)
+ if err != nil {
+ return
+ }
+
+ cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
+}
+
+func (cp *Provider) getBucket() *couchbase.Bucket {
+ c, err := couchbase.Connect(cp.savePath)
+ if err != nil {
+ return nil
+ }
+
+ pool, err := c.GetPool(cp.pool)
+ if err != nil {
+ return nil
+ }
+
+ bucket, err := pool.GetBucket(cp.bucket)
+ if err != nil {
+ return nil
+ }
+
+ return bucket
+}
+
+// SessionInit init couchbase session
+// savepath like couchbase server REST/JSON URL
+// e.g. http://host:port/, Pool, Bucket
+func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error {
+ cp.maxlifetime = maxlifetime
+ configs := strings.Split(savePath, ",")
+ if len(configs) > 0 {
+ cp.savePath = configs[0]
+ }
+ if len(configs) > 1 {
+ cp.pool = configs[1]
+ }
+ if len(configs) > 2 {
+ cp.bucket = configs[2]
+ }
+
+ return nil
+}
+
+// SessionRead read couchbase session by sid
+func (cp *Provider) SessionRead(sid string) (session.Store, error) {
+ cp.b = cp.getBucket()
+
+ var (
+ kv map[interface{}]interface{}
+ err error
+ doc []byte
+ )
+
+ err = cp.b.Get(sid, &doc)
+ if err != nil {
+ return nil, err
+ } else if doc == nil {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = session.DecodeGob(doc)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
+ return cs, nil
+}
+
+// SessionExist Check couchbase session exist.
+// it checks whether the sid exists or not.
+func (cp *Provider) SessionExist(sid string) bool {
+ cp.b = cp.getBucket()
+ defer cp.b.Close()
+
+ var doc []byte
+
+ if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
+ return false
+ }
+ return true
+}
+
+// SessionRegenerate remove oldsid and use sid to generate new session
+func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+ cp.b = cp.getBucket()
+
+ var doc []byte
+ if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil {
+ cp.b.Set(sid, int(cp.maxlifetime), "")
+ } else {
+ err := cp.b.Delete(oldsid)
+ if err != nil {
+ return nil, err
+ }
+ _, _ = cp.b.Add(sid, int(cp.maxlifetime), doc)
+ }
+
+ err := cp.b.Get(sid, &doc)
+ if err != nil {
+ return nil, err
+ }
+ var kv map[interface{}]interface{}
+ if doc == nil {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = session.DecodeGob(doc)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
+ return cs, nil
+}
+
+// SessionDestroy Remove bucket in this couchbase
+func (cp *Provider) SessionDestroy(sid string) error {
+ cp.b = cp.getBucket()
+ defer cp.b.Close()
+
+ cp.b.Delete(sid)
+ return nil
+}
+
+// SessionGC Recycle
+func (cp *Provider) SessionGC() {
+}
+
+// SessionAll return all active session
+func (cp *Provider) SessionAll() int {
+ return 0
+}
+
+func init() {
+ session.Register("couchbase", couchbpder)
+}
diff --git a/pkg/session/ledis/ledis_session.go b/pkg/session/ledis/ledis_session.go
new file mode 100644
index 00000000..ee81df67
--- /dev/null
+++ b/pkg/session/ledis/ledis_session.go
@@ -0,0 +1,173 @@
+// Package ledis provide session Provider
+package ledis
+
+import (
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+
+ "github.com/ledisdb/ledisdb/config"
+ "github.com/ledisdb/ledisdb/ledis"
+
+ "github.com/astaxie/beego/session"
+)
+
+var (
+ ledispder = &Provider{}
+ c *ledis.DB
+)
+
+// SessionStore ledis session store
+type SessionStore struct {
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+ maxlifetime int64
+}
+
+// Set value in ledis session
+func (ls *SessionStore) Set(key, value interface{}) error {
+ ls.lock.Lock()
+ defer ls.lock.Unlock()
+ ls.values[key] = value
+ return nil
+}
+
+// Get value in ledis session
+func (ls *SessionStore) Get(key interface{}) interface{} {
+ ls.lock.RLock()
+ defer ls.lock.RUnlock()
+ if v, ok := ls.values[key]; ok {
+ return v
+ }
+ return nil
+}
+
+// Delete value in ledis session
+func (ls *SessionStore) Delete(key interface{}) error {
+ ls.lock.Lock()
+ defer ls.lock.Unlock()
+ delete(ls.values, key)
+ return nil
+}
+
+// Flush clear all values in ledis session
+func (ls *SessionStore) Flush() error {
+ ls.lock.Lock()
+ defer ls.lock.Unlock()
+ ls.values = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID get ledis session id
+func (ls *SessionStore) SessionID() string {
+ return ls.sid
+}
+
+// SessionRelease save session values to ledis
+func (ls *SessionStore) SessionRelease(w http.ResponseWriter) {
+ b, err := session.EncodeGob(ls.values)
+ if err != nil {
+ return
+ }
+ c.Set([]byte(ls.sid), b)
+ c.Expire([]byte(ls.sid), ls.maxlifetime)
+}
+
+// Provider ledis session provider
+type Provider struct {
+ maxlifetime int64
+ savePath string
+ db int
+}
+
+// SessionInit init ledis session
+// savePath is the ledis data directory, optionally followed by the db
+// index, comma separated:
+//
+//	e.g. /var/lib/ledis/data,0
+//
+// (The previous doc showed a redis-style "host,pool,password" example,
+// which does not match how the value is used: configs[0] becomes
+// cfg.DataDir and configs[1] the db index.)
+func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error {
+	var err error
+	lp.maxlifetime = maxlifetime
+	configs := strings.Split(savePath, ",")
+	// Previously a value with more than two comma-separated fields was
+	// ignored entirely, leaving savePath empty; now the first two fields
+	// are always honoured.
+	if len(configs) >= 1 {
+		lp.savePath = configs[0]
+	}
+	if len(configs) >= 2 {
+		lp.db, err = strconv.Atoi(configs[1])
+		if err != nil {
+			return err
+		}
+	}
+	cfg := new(config.Config)
+	cfg.DataDir = lp.savePath
+
+	var ledisInstance *ledis.Ledis
+	ledisInstance, err = ledis.Open(cfg)
+	if err != nil {
+		return err
+	}
+	// The selected DB is kept in the package-level handle used by every
+	// store and provider method.
+	c, err = ledisInstance.Select(lp.db)
+	return err
+}
+
+// SessionRead read ledis session by sid
+func (lp *Provider) SessionRead(sid string) (session.Store, error) {
+ var (
+ kv map[interface{}]interface{}
+ err error
+ )
+
+ kvs, _ := c.Get([]byte(sid))
+
+ if len(kvs) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ if kv, err = session.DecodeGob(kvs); err != nil {
+ return nil, err
+ }
+ }
+
+ ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
+ return ls, nil
+}
+
+// SessionExist reports whether a value is stored under sid in ledis
+// (the Exists error is ignored; a failed lookup reads as "absent").
+func (lp *Provider) SessionExist(sid string) bool {
+	n, _ := c.Exists([]byte(sid))
+	return n != 0
+}
+
+// SessionRegenerate generate new sid for ledis session.
+//
+// The previous implementation tested the existence of the NEW sid, so a
+// regenerate against an existing old session always started from an
+// empty value; test oldsid instead (as the redis provider does) and
+// migrate its data to sid.
+func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+	count, _ := c.Exists([]byte(oldsid))
+	if count == 0 {
+		// oldsid doesn't exist: initialise the new sid with an empty
+		// value. Errors are ignored here, best effort — SessionRead
+		// below copes with a missing value anyway.
+		c.Set([]byte(sid), []byte(""))
+	} else {
+		// Move the old session data to the new sid and drop the old key.
+		data, _ := c.Get([]byte(oldsid))
+		c.Set([]byte(sid), data)
+		c.Del([]byte(oldsid))
+	}
+	c.Expire([]byte(sid), lp.maxlifetime)
+	return lp.SessionRead(sid)
+}
+
+// SessionDestroy delete ledis session by id.
+// The delete result is ignored; destroying an absent session succeeds
+// and this method always returns nil.
+func (lp *Provider) SessionDestroy(sid string) error {
+	c.Del([]byte(sid))
+	return nil
+}
+
+// SessionGC Impelment method, no used.
+func (lp *Provider) SessionGC() {
+}
+
+// SessionAll return all active session
+func (lp *Provider) SessionAll() int {
+ return 0
+}
+func init() {
+ session.Register("ledis", ledispder)
+}
diff --git a/pkg/session/memcache/sess_memcache.go b/pkg/session/memcache/sess_memcache.go
new file mode 100644
index 00000000..85a2d815
--- /dev/null
+++ b/pkg/session/memcache/sess_memcache.go
@@ -0,0 +1,230 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memcache for session provider
+//
+// depend on github.com/bradfitz/gomemcache/memcache
+//
+// go install github.com/bradfitz/gomemcache/memcache
+//
+// Usage:
+// import(
+// _ "github.com/astaxie/beego/session/memcache"
+// "github.com/astaxie/beego/session"
+// )
+//
+// func init() {
+// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``)
+// go globalSessions.GC()
+// }
+//
+// more docs: http://beego.me/docs/module/session.md
+package memcache
+
+import (
+ "net/http"
+ "strings"
+ "sync"
+
+ "github.com/astaxie/beego/session"
+
+ "github.com/bradfitz/gomemcache/memcache"
+)
+
+var mempder = &MemProvider{}
+var client *memcache.Client
+
+// SessionStore memcache session store
+type SessionStore struct {
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+ maxlifetime int64
+}
+
+// Set value in memcache session
+func (rs *SessionStore) Set(key, value interface{}) error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ rs.values[key] = value
+ return nil
+}
+
+// Get value in memcache session
+func (rs *SessionStore) Get(key interface{}) interface{} {
+ rs.lock.RLock()
+ defer rs.lock.RUnlock()
+ if v, ok := rs.values[key]; ok {
+ return v
+ }
+ return nil
+}
+
+// Delete value in memcache session
+func (rs *SessionStore) Delete(key interface{}) error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ delete(rs.values, key)
+ return nil
+}
+
+// Flush clear all values in memcache session
+func (rs *SessionStore) Flush() error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ rs.values = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID get memcache session id
+func (rs *SessionStore) SessionID() string {
+ return rs.sid
+}
+
+// SessionRelease save session values to memcache
+func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
+ b, err := session.EncodeGob(rs.values)
+ if err != nil {
+ return
+ }
+ item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)}
+ client.Set(&item)
+}
+
+// MemProvider memcache session provider
+type MemProvider struct {
+ maxlifetime int64
+ conninfo []string
+ poolsize int
+ password string
+}
+
+// SessionInit init memcache session
+// savePath is one or more memcache server addresses separated by ";"
+// e.g. 127.0.0.1:11211;127.0.0.1:11212
+// The shared package-level client is (re)created here; memcache.New
+// itself cannot fail, so this always returns nil.
+func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
+	rp.maxlifetime = maxlifetime
+	rp.conninfo = strings.Split(savePath, ";")
+	client = memcache.New(rp.conninfo...)
+	return nil
+}
+
+// SessionRead read memcache session by sid
+func (rp *MemProvider) SessionRead(sid string) (session.Store, error) {
+ if client == nil {
+ if err := rp.connectInit(); err != nil {
+ return nil, err
+ }
+ }
+ item, err := client.Get(sid)
+ if err != nil {
+ if err == memcache.ErrCacheMiss {
+ rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime}
+ return rs, nil
+ }
+ return nil, err
+ }
+ var kv map[interface{}]interface{}
+ if len(item.Value) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = session.DecodeGob(item.Value)
+ if err != nil {
+ return nil, err
+ }
+ }
+ rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ return rs, nil
+}
+
+// SessionExist reports whether a non-empty session value is stored
+// under sid in memcache; lookup errors read as "does not exist".
+func (rp *MemProvider) SessionExist(sid string) bool {
+	if client == nil {
+		if err := rp.connectInit(); err != nil {
+			return false
+		}
+	}
+	item, err := client.Get(sid)
+	return err == nil && len(item.Value) > 0
+}
+
+// SessionRegenerate generate new sid for memcache session.
+//
+// Fixes two defects in the previous implementation:
+//   - it inspected the NEW sid instead of oldsid, so existing session
+//     data was never migrated on regenerate;
+//   - on a cache miss client.Get returns a nil item, and the subsequent
+//     item.Key = sid dereferenced that nil pointer and panicked.
+func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+	if client == nil {
+		if err := rp.connectInit(); err != nil {
+			return nil, err
+		}
+	}
+	var contain []byte
+	item, err := client.Get(oldsid)
+	if err != nil || len(item.Value) == 0 {
+		// oldsid doesn't exist (or holds nothing): start the new session
+		// with an empty value. Set errors are ignored, best effort.
+		client.Set(&memcache.Item{Key: sid, Value: []byte(""), Expiration: int32(rp.maxlifetime)})
+	} else {
+		// Move the old session data to the new sid and refresh the TTL.
+		client.Delete(oldsid)
+		item.Key = sid
+		item.Expiration = int32(rp.maxlifetime)
+		client.Set(item)
+		contain = item.Value
+	}
+
+	var kv map[interface{}]interface{}
+	if len(contain) == 0 {
+		kv = make(map[interface{}]interface{})
+	} else {
+		var err error
+		kv, err = session.DecodeGob(contain)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+	return rs, nil
+}
+
+// SessionDestroy delete memcache session by id.
+// The client.Delete error is propagated to the caller — note this
+// includes memcache's "not found" error for an absent sid.
+func (rp *MemProvider) SessionDestroy(sid string) error {
+	if client == nil {
+		if err := rp.connectInit(); err != nil {
+			return err
+		}
+	}
+
+	return client.Delete(sid)
+}
+
+// connectInit (re)creates the shared package-level memcache client from
+// the configured server list. memcache.New does not fail, so the error
+// result exists only to satisfy the callers' checks and is always nil.
+func (rp *MemProvider) connectInit() error {
+	client = memcache.New(rp.conninfo...)
+	return nil
+}
+
+// SessionGC Impelment method, no used.
+func (rp *MemProvider) SessionGC() {
+}
+
+// SessionAll return all activeSession
+func (rp *MemProvider) SessionAll() int {
+ return 0
+}
+
+func init() {
+ session.Register("memcache", mempder)
+}
diff --git a/pkg/session/mysql/sess_mysql.go b/pkg/session/mysql/sess_mysql.go
new file mode 100644
index 00000000..301353ab
--- /dev/null
+++ b/pkg/session/mysql/sess_mysql.go
@@ -0,0 +1,228 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package mysql for session provider
+//
+// depends on github.com/go-sql-driver/mysql:
+//
+// go install github.com/go-sql-driver/mysql
+//
+// mysql session support need create table as sql:
+// CREATE TABLE `session` (
+// `session_key` char(64) NOT NULL,
+// `session_data` blob,
+// `session_expiry` int(11) unsigned NOT NULL,
+// PRIMARY KEY (`session_key`)
+// ) ENGINE=MyISAM DEFAULT CHARSET=utf8;
+//
+// Usage:
+// import(
+// _ "github.com/astaxie/beego/session/mysql"
+// "github.com/astaxie/beego/session"
+// )
+//
+// func init() {
+// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``)
+// go globalSessions.GC()
+// }
+//
+// more docs: http://beego.me/docs/module/session.md
+package mysql
+
+import (
+ "database/sql"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/astaxie/beego/session"
+ // import mysql driver
+ _ "github.com/go-sql-driver/mysql"
+)
+
+var (
+ // TableName store the session in MySQL
+ TableName = "session"
+ mysqlpder = &Provider{}
+)
+
+// SessionStore mysql session store
+type SessionStore struct {
+ c *sql.DB
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+}
+
+// Set value in mysql session.
+// it is temp value in map.
+func (st *SessionStore) Set(key, value interface{}) error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ st.values[key] = value
+ return nil
+}
+
+// Get value from mysql session
+func (st *SessionStore) Get(key interface{}) interface{} {
+ st.lock.RLock()
+ defer st.lock.RUnlock()
+ if v, ok := st.values[key]; ok {
+ return v
+ }
+ return nil
+}
+
+// Delete value in mysql session
+func (st *SessionStore) Delete(key interface{}) error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ delete(st.values, key)
+ return nil
+}
+
+// Flush clear all values in mysql session
+func (st *SessionStore) Flush() error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ st.values = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID get session id of this mysql session store
+func (st *SessionStore) SessionID() string {
+ return st.sid
+}
+
+// SessionRelease save mysql session values to database.
+// must call this method to save values to database.
+func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
+ defer st.c.Close()
+ b, err := session.EncodeGob(st.values)
+ if err != nil {
+ return
+ }
+ st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?",
+ b, time.Now().Unix(), st.sid)
+}
+
+// Provider mysql session provider
+type Provider struct {
+ maxlifetime int64
+ savePath string
+}
+
+// connect to mysql
+func (mp *Provider) connectInit() *sql.DB {
+ db, e := sql.Open("mysql", mp.savePath)
+ if e != nil {
+ return nil
+ }
+ return db
+}
+
+// SessionInit init mysql session.
+// savepath is the connection string of mysql.
+func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error {
+ mp.maxlifetime = maxlifetime
+ mp.savePath = savePath
+ return nil
+}
+
+// SessionRead get mysql session by sid.
+// A missing row is created on first access. A scan failure other than
+// sql.ErrNoRows is now returned to the caller instead of being silently
+// treated as an empty session — this matches the postgres provider.
+// NOTE(review): connectInit returns nil on a bad DSN, which would panic
+// on QueryRow below; also a new *sql.DB is opened per call — confirm
+// whether pooling should be shared instead.
+func (mp *Provider) SessionRead(sid string) (session.Store, error) {
+	c := mp.connectInit()
+	row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
+	var sessiondata []byte
+	err := row.Scan(&sessiondata)
+	if err == sql.ErrNoRows {
+		// First access for this sid: seed an empty row.
+		if _, err = c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
+			sid, "", time.Now().Unix()); err != nil {
+			return nil, err
+		}
+	} else if err != nil {
+		return nil, err
+	}
+	var kv map[interface{}]interface{}
+	if len(sessiondata) == 0 {
+		kv = make(map[interface{}]interface{})
+	} else {
+		kv, err = session.DecodeGob(sessiondata)
+		if err != nil {
+			return nil, err
+		}
+	}
+	rs := &SessionStore{c: c, sid: sid, values: kv}
+	return rs, nil
+}
+
+// SessionExist check mysql session exist
+// NOTE(review): any scan error other than sql.ErrNoRows (e.g. a broken
+// connection) also makes this return true — that looks accidental;
+// confirm before relying on it.
+func (mp *Provider) SessionExist(sid string) bool {
+	c := mp.connectInit()
+	defer c.Close()
+	row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
+	var sessiondata []byte
+	err := row.Scan(&sessiondata)
+	return err != sql.ErrNoRows
+}
+
+// SessionRegenerate generate new sid for mysql session
+func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+ c := mp.connectInit()
+ row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid)
+ var sessiondata []byte
+ err := row.Scan(&sessiondata)
+ if err == sql.ErrNoRows {
+ c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
+ }
+ c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid)
+ var kv map[interface{}]interface{}
+ if len(sessiondata) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = session.DecodeGob(sessiondata)
+ if err != nil {
+ return nil, err
+ }
+ }
+ rs := &SessionStore{c: c, sid: sid, values: kv}
+ return rs, nil
+}
+
+// SessionDestroy delete mysql session by sid
+func (mp *Provider) SessionDestroy(sid string) error {
+ c := mp.connectInit()
+ c.Exec("DELETE FROM "+TableName+" where session_key=?", sid)
+ c.Close()
+ return nil
+}
+
+// SessionGC delete expired values in mysql session
+func (mp *Provider) SessionGC() {
+ c := mp.connectInit()
+ c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
+ c.Close()
+}
+
+// SessionAll count values in mysql session
+func (mp *Provider) SessionAll() int {
+ c := mp.connectInit()
+ defer c.Close()
+ var total int
+ err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total)
+ if err != nil {
+ return 0
+ }
+ return total
+}
+
+func init() {
+ session.Register("mysql", mysqlpder)
+}
diff --git a/pkg/session/postgres/sess_postgresql.go b/pkg/session/postgres/sess_postgresql.go
new file mode 100644
index 00000000..0b8b9645
--- /dev/null
+++ b/pkg/session/postgres/sess_postgresql.go
@@ -0,0 +1,243 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package postgres for session provider
+//
+// depends on github.com/lib/pq:
+//
+// go install github.com/lib/pq
+//
+//
+// needs this table in your database:
+//
+// CREATE TABLE session (
+// session_key char(64) NOT NULL,
+// session_data bytea,
+// session_expiry timestamp NOT NULL,
+// CONSTRAINT session_key PRIMARY KEY(session_key)
+// );
+//
+// will be activated with these settings in app.conf:
+//
+// SessionOn = true
+// SessionProvider = postgresql
+// SessionSavePath = "user=a password=b dbname=c sslmode=disable"
+// SessionName = session
+//
+//
+// Usage:
+// import(
+// _ "github.com/astaxie/beego/session/postgresql"
+// "github.com/astaxie/beego/session"
+// )
+//
+// func init() {
+// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``)
+// go globalSessions.GC()
+// }
+//
+// more docs: http://beego.me/docs/module/session.md
+package postgres
+
+import (
+ "database/sql"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/astaxie/beego/session"
+ // import postgresql Driver
+ _ "github.com/lib/pq"
+)
+
+var postgresqlpder = &Provider{}
+
+// SessionStore postgresql session store
+type SessionStore struct {
+ c *sql.DB
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+}
+
+// Set value in postgresql session.
+// it is temp value in map.
+func (st *SessionStore) Set(key, value interface{}) error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ st.values[key] = value
+ return nil
+}
+
+// Get value from postgresql session
+func (st *SessionStore) Get(key interface{}) interface{} {
+ st.lock.RLock()
+ defer st.lock.RUnlock()
+ if v, ok := st.values[key]; ok {
+ return v
+ }
+ return nil
+}
+
+// Delete value in postgresql session
+func (st *SessionStore) Delete(key interface{}) error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ delete(st.values, key)
+ return nil
+}
+
+// Flush clear all values in postgresql session
+func (st *SessionStore) Flush() error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ st.values = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID get session id of this postgresql session store
+func (st *SessionStore) SessionID() string {
+ return st.sid
+}
+
+// SessionRelease save postgresql session values to database.
+// must call this method to save values to database.
+func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
+ defer st.c.Close()
+ b, err := session.EncodeGob(st.values)
+ if err != nil {
+ return
+ }
+ st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3",
+ b, time.Now().Format(time.RFC3339), st.sid)
+
+}
+
+// Provider postgresql session provider
+type Provider struct {
+ maxlifetime int64
+ savePath string
+}
+
+// connect to postgresql
+func (mp *Provider) connectInit() *sql.DB {
+ db, e := sql.Open("postgres", mp.savePath)
+ if e != nil {
+ return nil
+ }
+ return db
+}
+
+// SessionInit init postgresql session.
+// savepath is the connection string of postgresql.
+func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error {
+ mp.maxlifetime = maxlifetime
+ mp.savePath = savePath
+ return nil
+}
+
+// SessionRead get postgresql session by sid
+func (mp *Provider) SessionRead(sid string) (session.Store, error) {
+ c := mp.connectInit()
+ row := c.QueryRow("select session_data from session where session_key=$1", sid)
+ var sessiondata []byte
+ err := row.Scan(&sessiondata)
+ if err == sql.ErrNoRows {
+ _, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
+ sid, "", time.Now().Format(time.RFC3339))
+
+ if err != nil {
+ return nil, err
+ }
+ } else if err != nil {
+ return nil, err
+ }
+
+ var kv map[interface{}]interface{}
+ if len(sessiondata) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = session.DecodeGob(sessiondata)
+ if err != nil {
+ return nil, err
+ }
+ }
+ rs := &SessionStore{c: c, sid: sid, values: kv}
+ return rs, nil
+}
+
+// SessionExist check postgresql session exist
+// NOTE(review): any scan error other than sql.ErrNoRows (e.g. a lost
+// connection) also makes this return true — that looks accidental;
+// confirm before relying on it.
+func (mp *Provider) SessionExist(sid string) bool {
+	c := mp.connectInit()
+	defer c.Close()
+	row := c.QueryRow("select session_data from session where session_key=$1", sid)
+	var sessiondata []byte
+	err := row.Scan(&sessiondata)
+	return err != sql.ErrNoRows
+}
+
+// SessionRegenerate generate new sid for postgresql session
+func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+ c := mp.connectInit()
+ row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
+ var sessiondata []byte
+ err := row.Scan(&sessiondata)
+ if err == sql.ErrNoRows {
+ c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
+ oldsid, "", time.Now().Format(time.RFC3339))
+ }
+ c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid)
+ var kv map[interface{}]interface{}
+ if len(sessiondata) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = session.DecodeGob(sessiondata)
+ if err != nil {
+ return nil, err
+ }
+ }
+ rs := &SessionStore{c: c, sid: sid, values: kv}
+ return rs, nil
+}
+
+// SessionDestroy delete postgresql session by sid
+func (mp *Provider) SessionDestroy(sid string) error {
+ c := mp.connectInit()
+ c.Exec("DELETE FROM session where session_key=$1", sid)
+ c.Close()
+ return nil
+}
+
+// SessionGC delete expired values in postgresql session
+func (mp *Provider) SessionGC() {
+ c := mp.connectInit()
+ c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
+ c.Close()
+}
+
+// SessionAll count values in postgresql session
+func (mp *Provider) SessionAll() int {
+ c := mp.connectInit()
+ defer c.Close()
+ var total int
+ err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
+ if err != nil {
+ return 0
+ }
+ return total
+}
+
+func init() {
+ session.Register("postgresql", postgresqlpder)
+}
diff --git a/pkg/session/redis/sess_redis.go b/pkg/session/redis/sess_redis.go
new file mode 100644
index 00000000..5c382d61
--- /dev/null
+++ b/pkg/session/redis/sess_redis.go
@@ -0,0 +1,261 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package redis for session provider
+//
+// depend on github.com/gomodule/redigo/redis
+//
+// go install github.com/gomodule/redigo/redis
+//
+// Usage:
+// import(
+// _ "github.com/astaxie/beego/session/redis"
+// "github.com/astaxie/beego/session"
+// )
+//
+// func init() {
+// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``)
+// go globalSessions.GC()
+// }
+//
+// more docs: http://beego.me/docs/module/session.md
+package redis
+
+import (
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/astaxie/beego/session"
+
+ "github.com/gomodule/redigo/redis"
+)
+
+var redispder = &Provider{}
+
+// MaxPoolSize redis max pool size
+var MaxPoolSize = 100
+
+// SessionStore redis session store
+type SessionStore struct {
+ p *redis.Pool
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+ maxlifetime int64
+}
+
+// Set value in redis session
+func (rs *SessionStore) Set(key, value interface{}) error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ rs.values[key] = value
+ return nil
+}
+
+// Get value in redis session
+func (rs *SessionStore) Get(key interface{}) interface{} {
+ rs.lock.RLock()
+ defer rs.lock.RUnlock()
+ if v, ok := rs.values[key]; ok {
+ return v
+ }
+ return nil
+}
+
+// Delete value in redis session
+func (rs *SessionStore) Delete(key interface{}) error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ delete(rs.values, key)
+ return nil
+}
+
+// Flush clear all values in redis session
+func (rs *SessionStore) Flush() error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ rs.values = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID get redis session id
+func (rs *SessionStore) SessionID() string {
+ return rs.sid
+}
+
+// SessionRelease save session values to redis
+func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
+ b, err := session.EncodeGob(rs.values)
+ if err != nil {
+ return
+ }
+ c := rs.p.Get()
+ defer c.Close()
+ c.Do("SETEX", rs.sid, rs.maxlifetime, string(b))
+}
+
+// Provider redis session provider
+type Provider struct {
+ maxlifetime int64
+ savePath string
+ poolsize int
+ password string
+ dbNum int
+ poollist *redis.Pool
+}
+
+// SessionInit init redis session
+// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
+// e.g. 127.0.0.1:6379,100,astaxie,0,30
+func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
+ rp.maxlifetime = maxlifetime
+ configs := strings.Split(savePath, ",")
+ if len(configs) > 0 {
+ rp.savePath = configs[0]
+ }
+ if len(configs) > 1 {
+ poolsize, err := strconv.Atoi(configs[1])
+ if err != nil || poolsize < 0 {
+ rp.poolsize = MaxPoolSize
+ } else {
+ rp.poolsize = poolsize
+ }
+ } else {
+ rp.poolsize = MaxPoolSize
+ }
+ if len(configs) > 2 {
+ rp.password = configs[2]
+ }
+ if len(configs) > 3 {
+ dbnum, err := strconv.Atoi(configs[3])
+ if err != nil || dbnum < 0 {
+ rp.dbNum = 0
+ } else {
+ rp.dbNum = dbnum
+ }
+ } else {
+ rp.dbNum = 0
+ }
+ var idleTimeout time.Duration = 0
+ if len(configs) > 4 {
+ timeout, err := strconv.Atoi(configs[4])
+ if err == nil && timeout > 0 {
+ idleTimeout = time.Duration(timeout) * time.Second
+ }
+ }
+ rp.poollist = &redis.Pool{
+ Dial: func() (redis.Conn, error) {
+ c, err := redis.Dial("tcp", rp.savePath)
+ if err != nil {
+ return nil, err
+ }
+ if rp.password != "" {
+ if _, err = c.Do("AUTH", rp.password); err != nil {
+ c.Close()
+ return nil, err
+ }
+ }
+ // some redis proxy such as twemproxy is not support select command
+ if rp.dbNum > 0 {
+ _, err = c.Do("SELECT", rp.dbNum)
+ if err != nil {
+ c.Close()
+ return nil, err
+ }
+ }
+ return c, err
+ },
+ MaxIdle: rp.poolsize,
+ }
+
+ rp.poollist.IdleTimeout = idleTimeout
+
+ return rp.poollist.Get().Err()
+}
+
+// SessionRead read redis session by sid
+func (rp *Provider) SessionRead(sid string) (session.Store, error) {
+ c := rp.poollist.Get()
+ defer c.Close()
+
+ var kv map[interface{}]interface{}
+
+ kvs, err := redis.String(c.Do("GET", sid))
+ if err != nil && err != redis.ErrNil {
+ return nil, err
+ }
+ if len(kvs) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
+ return nil, err
+ }
+ }
+
+ rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ return rs, nil
+}
+
+// SessionExist check redis session exist by sid
+func (rp *Provider) SessionExist(sid string) bool {
+ c := rp.poollist.Get()
+ defer c.Close()
+
+ if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 {
+ return false
+ }
+ return true
+}
+
+// SessionRegenerate generate new sid for redis session
+// Errors from the redis commands issued here are ignored — presumably
+// deliberate best-effort; the final state is whatever SessionRead finds
+// under sid afterwards.
+func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+	c := rp.poollist.Get()
+	defer c.Close()
+
+	if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 {
+		// oldsid doesn't exists, set the new sid directly
+		// ignore error here, since if it return error
+		// the existed value will be 0
+		c.Do("SET", sid, "", "EX", rp.maxlifetime)
+	} else {
+		// Move the old key to the new sid and refresh its TTL.
+		c.Do("RENAME", oldsid, sid)
+		c.Do("EXPIRE", sid, rp.maxlifetime)
+	}
+	return rp.SessionRead(sid)
+}
+
+// SessionDestroy delete redis session by id
+func (rp *Provider) SessionDestroy(sid string) error {
+ c := rp.poollist.Get()
+ defer c.Close()
+
+ c.Do("DEL", sid)
+ return nil
+}
+
+// SessionGC Impelment method, no used.
+func (rp *Provider) SessionGC() {
+}
+
+// SessionAll return all activeSession
+func (rp *Provider) SessionAll() int {
+ return 0
+}
+
+func init() {
+ session.Register("redis", redispder)
+}
diff --git a/pkg/session/redis_cluster/redis_cluster.go b/pkg/session/redis_cluster/redis_cluster.go
new file mode 100644
index 00000000..2fe300df
--- /dev/null
+++ b/pkg/session/redis_cluster/redis_cluster.go
@@ -0,0 +1,220 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package redis_cluster for session provider
+//
+// depend on github.com/go-redis/redis
+//
+// go install github.com/go-redis/redis
+//
+// Usage:
+// import(
+// _ "github.com/astaxie/beego/session/redis_cluster"
+// "github.com/astaxie/beego/session"
+// )
+//
+// func init() {
+// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``)
+// go globalSessions.GC()
+// }
+//
+// more docs: http://beego.me/docs/module/session.md
+package redis_cluster
+import (
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+ "github.com/astaxie/beego/session"
+ rediss "github.com/go-redis/redis"
+ "time"
+)
+
+var redispder = &Provider{}
+
+// MaxPoolSize redis_cluster max pool size
+var MaxPoolSize = 1000
+
+// SessionStore redis_cluster session store
+type SessionStore struct {
+ p *rediss.ClusterClient
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+ maxlifetime int64
+}
+
+// Set value in redis_cluster session
+func (rs *SessionStore) Set(key, value interface{}) error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ rs.values[key] = value
+ return nil
+}
+
+// Get returns the value stored under key in this redis_cluster session,
+// or nil when the key is absent.
+func (rs *SessionStore) Get(key interface{}) interface{} {
+	rs.lock.RLock()
+	defer rs.lock.RUnlock()
+	v, ok := rs.values[key]
+	if !ok {
+		return nil
+	}
+	return v
+}
+
+// Delete value in redis_cluster session
+func (rs *SessionStore) Delete(key interface{}) error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ delete(rs.values, key)
+ return nil
+}
+
+// Flush clear all values in redis_cluster session
+func (rs *SessionStore) Flush() error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ rs.values = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID get redis_cluster session id
+func (rs *SessionStore) SessionID() string {
+ return rs.sid
+}
+
+// SessionRelease gob-encodes the in-memory values and writes them to the
+// cluster with the configured lifetime as TTL. Encode and SET errors are
+// silently dropped because the session.Store interface gives SessionRelease
+// no way to report failure.
+func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
+	b, err := session.EncodeGob(rs.values)
+	if err != nil {
+		return
+	}
+	c := rs.p
+	c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second)
+}
+
+// Provider redis_cluster session provider
+type Provider struct {
+ maxlifetime int64
+ savePath string
+ poolsize int
+ password string
+ dbNum int
+ poollist *rediss.ClusterClient
+}
+
+// SessionInit initializes the redis_cluster provider.
+// savePath is a comma-separated config: "addrs,poolsize,password,dbnum" where
+// addrs is itself a semicolon-separated node list,
+// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0
+func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
+	rp.maxlifetime = maxlifetime
+	configs := strings.Split(savePath, ",")
+	if len(configs) > 0 {
+		rp.savePath = configs[0]
+	}
+	if len(configs) > 1 {
+		// Fall back to MaxPoolSize on a malformed or negative pool size.
+		poolsize, err := strconv.Atoi(configs[1])
+		if err != nil || poolsize < 0 {
+			rp.poolsize = MaxPoolSize
+		} else {
+			rp.poolsize = poolsize
+		}
+	} else {
+		rp.poolsize = MaxPoolSize
+	}
+	if len(configs) > 2 {
+		rp.password = configs[2]
+	}
+	if len(configs) > 3 {
+		// NOTE(review): dbNum is parsed for config-format compatibility but is
+		// never passed to ClusterOptions below — confirm this is intentional.
+		dbnum, err := strconv.Atoi(configs[3])
+		if err != nil || dbnum < 0 {
+			rp.dbNum = 0
+		} else {
+			rp.dbNum = dbnum
+		}
+	} else {
+		rp.dbNum = 0
+	}
+
+	rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
+		Addrs:    strings.Split(rp.savePath, ";"),
+		Password: rp.password,
+		PoolSize: rp.poolsize,
+	})
+	// Ping verifies connectivity up front so misconfiguration fails fast.
+	return rp.poollist.Ping().Err()
+}
+
+// SessionRead read redis_cluster session by sid
+func (rp *Provider) SessionRead(sid string) (session.Store, error) {
+ var kv map[interface{}]interface{}
+ kvs, err := rp.poollist.Get(sid).Result()
+ if err != nil && err != rediss.Nil {
+ return nil, err
+ }
+ if len(kvs) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
+ return nil, err
+ }
+ }
+
+ rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ return rs, nil
+}
+
+// SessionExist reports whether a redis_cluster session keyed by sid exists.
+// Redis errors are conservatively treated as "does not exist".
+func (rp *Provider) SessionExist(sid string) bool {
+	existed, err := rp.poollist.Exists(sid).Result()
+	return err == nil && existed != 0
+}
+
+// SessionRegenerate moves the session stored under oldsid to sid and returns
+// a store for the new id. NOTE(review): Rename/Expire results are discarded;
+// a failure is only surfaced by the final SessionRead — confirm acceptable.
+func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+	c := rp.poollist
+
+	if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
+		// oldsid doesn't exists, set the new sid directly
+		// ignore error here, since if it return error
+		// the existed value will be 0
+		c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second)
+	} else {
+		c.Rename(oldsid, sid)
+		c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second)
+	}
+	return rp.SessionRead(sid)
+}
+
+// SessionDestroy removes the redis_cluster session stored under sid.
+// Deleting a missing key is a no-op; this method always reports success.
+func (rp *Provider) SessionDestroy(sid string) error {
+	rp.poollist.Del(sid)
+	return nil
+}
+
+// SessionGC implements the session.Provider interface; it is a no-op because
+// values are written with a TTL (see SessionRelease), so redis expires stale
+// sessions itself.
+func (rp *Provider) SessionGC() {
+}
+
+// SessionAll return all activeSession
+func (rp *Provider) SessionAll() int {
+ return 0
+}
+
+func init() {
+ session.Register("redis_cluster", redispder)
+}
diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel.go b/pkg/session/redis_sentinel/sess_redis_sentinel.go
new file mode 100644
index 00000000..6ecb2977
--- /dev/null
+++ b/pkg/session/redis_sentinel/sess_redis_sentinel.go
@@ -0,0 +1,234 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package redis_sentinel for session provider
+//
+// depend on github.com/go-redis/redis
+//
+// go install github.com/go-redis/redis
+//
+// Usage:
+// import(
+// _ "github.com/astaxie/beego/session/redis_sentinel"
+// "github.com/astaxie/beego/session"
+// )
+//
+// func init() {
+// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``)
+// go globalSessions.GC()
+// }
+//
+// more detail about params: please check the notes on the function SessionInit in this package
+package redis_sentinel
+
+import (
+ "github.com/astaxie/beego/session"
+ "github.com/go-redis/redis"
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+var redispder = &Provider{}
+
+// DefaultPoolSize redis_sentinel default pool size
+var DefaultPoolSize = 100
+
+// SessionStore redis_sentinel session store
+type SessionStore struct {
+ p *redis.Client
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+ maxlifetime int64
+}
+
+// Set value in redis_sentinel session
+func (rs *SessionStore) Set(key, value interface{}) error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ rs.values[key] = value
+ return nil
+}
+
+// Get value in redis_sentinel session
+func (rs *SessionStore) Get(key interface{}) interface{} {
+ rs.lock.RLock()
+ defer rs.lock.RUnlock()
+ if v, ok := rs.values[key]; ok {
+ return v
+ }
+ return nil
+}
+
+// Delete value in redis_sentinel session
+func (rs *SessionStore) Delete(key interface{}) error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ delete(rs.values, key)
+ return nil
+}
+
+// Flush clear all values in redis_sentinel session
+func (rs *SessionStore) Flush() error {
+ rs.lock.Lock()
+ defer rs.lock.Unlock()
+ rs.values = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID get redis_sentinel session id
+func (rs *SessionStore) SessionID() string {
+ return rs.sid
+}
+
+// SessionRelease save session values to redis_sentinel
+func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
+ b, err := session.EncodeGob(rs.values)
+ if err != nil {
+ return
+ }
+ c := rs.p
+ c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
+}
+
+// Provider redis_sentinel session provider
+type Provider struct {
+ maxlifetime int64
+ savePath string
+ poolsize int
+ password string
+ dbNum int
+ poollist *redis.Client
+ masterName string
+}
+
+// SessionInit initializes the redis_sentinel provider.
+// savePath is a comma-separated config: "addrs,poolsize,password,dbnum,masterName"
+// where addrs is a semicolon-separated sentinel list,
+// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
+func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
+	rp.maxlifetime = maxlifetime
+	configs := strings.Split(savePath, ",")
+	if len(configs) > 0 {
+		rp.savePath = configs[0]
+	}
+	if len(configs) > 1 {
+		// Fall back to DefaultPoolSize on a malformed or negative pool size.
+		poolsize, err := strconv.Atoi(configs[1])
+		if err != nil || poolsize < 0 {
+			rp.poolsize = DefaultPoolSize
+		} else {
+			rp.poolsize = poolsize
+		}
+	} else {
+		rp.poolsize = DefaultPoolSize
+	}
+	if len(configs) > 2 {
+		rp.password = configs[2]
+	}
+	if len(configs) > 3 {
+		// Malformed or negative db numbers fall back to DB 0.
+		dbnum, err := strconv.Atoi(configs[3])
+		if err != nil || dbnum < 0 {
+			rp.dbNum = 0
+		} else {
+			rp.dbNum = dbnum
+		}
+	} else {
+		rp.dbNum = 0
+	}
+	// The sentinel master name defaults to "mymaster" when omitted or empty.
+	if len(configs) > 4 {
+		if configs[4] != "" {
+			rp.masterName = configs[4]
+		} else {
+			rp.masterName = "mymaster"
+		}
+	} else {
+		rp.masterName = "mymaster"
+	}
+
+	rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
+		SentinelAddrs: strings.Split(rp.savePath, ";"),
+		Password:      rp.password,
+		PoolSize:      rp.poolsize,
+		DB:            rp.dbNum,
+		MasterName:    rp.masterName,
+	})
+
+	// Ping verifies connectivity up front so misconfiguration fails fast.
+	return rp.poollist.Ping().Err()
+}
+
+// SessionRead read redis_sentinel session by sid
+func (rp *Provider) SessionRead(sid string) (session.Store, error) {
+ var kv map[interface{}]interface{}
+ kvs, err := rp.poollist.Get(sid).Result()
+ if err != nil && err != redis.Nil {
+ return nil, err
+ }
+ if len(kvs) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
+ return nil, err
+ }
+ }
+
+ rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
+ return rs, nil
+}
+
+// SessionExist reports whether a redis_sentinel session keyed by sid exists.
+// Redis errors are conservatively treated as "does not exist".
+func (rp *Provider) SessionExist(sid string) bool {
+	existed, err := rp.poollist.Exists(sid).Result()
+	return err == nil && existed != 0
+}
+
+// SessionRegenerate generate new sid for redis_sentinel session
+func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+ c := rp.poollist
+
+ if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
+ // oldsid doesn't exists, set the new sid directly
+ // ignore error here, since if it return error
+ // the existed value will be 0
+ c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second)
+ } else {
+ c.Rename(oldsid, sid)
+ c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
+ }
+ return rp.SessionRead(sid)
+}
+
+// SessionDestroy removes the redis_sentinel session stored under sid.
+// Deleting a missing key is a no-op; this method always reports success.
+func (rp *Provider) SessionDestroy(sid string) error {
+	rp.poollist.Del(sid)
+	return nil
+}
+
+// SessionGC implements the session.Provider interface; it is a no-op because
+// values are written with a TTL (see SessionRelease), so redis expires stale
+// sessions itself.
+func (rp *Provider) SessionGC() {
+}
+
+// SessionAll return all activeSession
+func (rp *Provider) SessionAll() int {
+ return 0
+}
+
+func init() {
+ session.Register("redis_sentinel", redispder)
+}
diff --git a/pkg/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/session/redis_sentinel/sess_redis_sentinel_test.go
new file mode 100644
index 00000000..fd4155c6
--- /dev/null
+++ b/pkg/session/redis_sentinel/sess_redis_sentinel_test.go
@@ -0,0 +1,90 @@
+package redis_sentinel
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/astaxie/beego/session"
+)
+
+func TestRedisSentinel(t *testing.T) {
+ sessionConfig := &session.ManagerConfig{
+ CookieName: "gosessionid",
+ EnableSetCookie: true,
+ Gclifetime: 3600,
+ Maxlifetime: 3600,
+ Secure: false,
+ CookieLifeTime: 3600,
+ ProviderConfig: "127.0.0.1:6379,100,,0,master",
+ }
+ globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
+ if e != nil {
+ t.Log(e)
+ return
+ }
+ //todo test if e==nil
+ go globalSessions.GC()
+
+ r, _ := http.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+
+ sess, err := globalSessions.SessionStart(w, r)
+ if err != nil {
+ t.Fatal("session start failed:", err)
+ }
+ defer sess.SessionRelease(w)
+
+ // SET AND GET
+ err = sess.Set("username", "astaxie")
+ if err != nil {
+ t.Fatal("set username failed:", err)
+ }
+ username := sess.Get("username")
+ if username != "astaxie" {
+ t.Fatal("get username failed")
+ }
+
+ // DELETE
+ err = sess.Delete("username")
+ if err != nil {
+ t.Fatal("delete username failed:", err)
+ }
+ username = sess.Get("username")
+ if username != nil {
+ t.Fatal("delete username failed")
+ }
+
+ // FLUSH
+ err = sess.Set("username", "astaxie")
+ if err != nil {
+ t.Fatal("set failed:", err)
+ }
+ err = sess.Set("password", "1qaz2wsx")
+ if err != nil {
+ t.Fatal("set failed:", err)
+ }
+ username = sess.Get("username")
+ if username != "astaxie" {
+ t.Fatal("get username failed")
+ }
+ password := sess.Get("password")
+ if password != "1qaz2wsx" {
+ t.Fatal("get password failed")
+ }
+ err = sess.Flush()
+ if err != nil {
+ t.Fatal("flush failed:", err)
+ }
+ username = sess.Get("username")
+ if username != nil {
+ t.Fatal("flush failed")
+ }
+ password = sess.Get("password")
+ if password != nil {
+ t.Fatal("flush failed")
+ }
+
+ sess.SessionRelease(w)
+
+}
diff --git a/pkg/session/sess_cookie.go b/pkg/session/sess_cookie.go
new file mode 100644
index 00000000..6ad5debc
--- /dev/null
+++ b/pkg/session/sess_cookie.go
@@ -0,0 +1,180 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package session
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "encoding/json"
+ "net/http"
+ "net/url"
+ "sync"
+)
+
+var cookiepder = &CookieProvider{}
+
+// CookieSessionStore Cookie SessionStore
+type CookieSessionStore struct {
+ sid string
+ values map[interface{}]interface{} // session data
+ lock sync.RWMutex
+}
+
+// Set value to cookie session.
+// the value are encoded as gob with hash block string.
+func (st *CookieSessionStore) Set(key, value interface{}) error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ st.values[key] = value
+ return nil
+}
+
+// Get value from cookie session
+func (st *CookieSessionStore) Get(key interface{}) interface{} {
+ st.lock.RLock()
+ defer st.lock.RUnlock()
+ if v, ok := st.values[key]; ok {
+ return v
+ }
+ return nil
+}
+
+// Delete value in cookie session
+func (st *CookieSessionStore) Delete(key interface{}) error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ delete(st.values, key)
+ return nil
+}
+
+// Flush Clean all values in cookie session
+func (st *CookieSessionStore) Flush() error {
+ st.lock.Lock()
+ defer st.lock.Unlock()
+ st.values = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID Return id of this cookie session
+func (st *CookieSessionStore) SessionID() string {
+ return st.sid
+}
+
+// SessionRelease serializes the session values into an encrypted cookie and
+// sets it on the response. On encode failure the cookie is silently left
+// unchanged — the session.Store interface offers no error channel here.
+func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
+	st.lock.Lock()
+	// Hold the lock only around the read of st.values during encoding.
+	encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values)
+	st.lock.Unlock()
+	if err == nil {
+		cookie := &http.Cookie{Name: cookiepder.config.CookieName,
+			Value:    url.QueryEscape(encodedCookie),
+			Path:     "/",
+			HttpOnly: true,
+			Secure:   cookiepder.config.Secure,
+			MaxAge:   cookiepder.config.Maxage}
+		http.SetCookie(w, cookie)
+	}
+}
+
+type cookieConfig struct {
+ SecurityKey string `json:"securityKey"`
+ BlockKey string `json:"blockKey"`
+ SecurityName string `json:"securityName"`
+ CookieName string `json:"cookieName"`
+ Secure bool `json:"secure"`
+ Maxage int `json:"maxage"`
+}
+
+// CookieProvider Cookie session provider
+type CookieProvider struct {
+ maxlifetime int64
+ config *cookieConfig
+ block cipher.Block
+}
+
+// SessionInit Init cookie session provider with max lifetime and config json.
+// maxlifetime is ignored.
+// json config:
+//	securityKey - hash string
+//	blockKey - gob encode hash string. it's saved as aes crypto.
+//	securityName - recognized name in encoded cookie string
+//	cookieName - cookie name
+//	maxage - cookie max life time.
+// NOTE(review): a user-supplied blockKey must be a valid AES key length
+// (16/24/32 bytes) or aes.NewCipher below returns an error.
+func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error {
+	pder.config = &cookieConfig{}
+	err := json.Unmarshal([]byte(config), pder.config)
+	if err != nil {
+		return err
+	}
+	// Generate random key material when the config leaves these blank.
+	if pder.config.BlockKey == "" {
+		pder.config.BlockKey = string(generateRandomKey(16))
+	}
+	if pder.config.SecurityName == "" {
+		pder.config.SecurityName = string(generateRandomKey(20))
+	}
+	pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
+	if err != nil {
+		return err
+	}
+	pder.maxlifetime = maxlifetime
+	return nil
+}
+
+// SessionRead Get SessionStore in cooke.
+// decode cooke string to map and put into SessionStore with sid.
+func (pder *CookieProvider) SessionRead(sid string) (Store, error) {
+ maps, _ := decodeCookie(pder.block,
+ pder.config.SecurityKey,
+ pder.config.SecurityName,
+ sid, pder.maxlifetime)
+ if maps == nil {
+ maps = make(map[interface{}]interface{})
+ }
+ rs := &CookieSessionStore{sid: sid, values: maps}
+ return rs, nil
+}
+
+// SessionExist Cookie session is always existed
+func (pder *CookieProvider) SessionExist(sid string) bool {
+ return true
+}
+
+// SessionRegenerate Implement method, no used.
+func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
+ return nil, nil
+}
+
+// SessionDestroy Implement method, no used.
+func (pder *CookieProvider) SessionDestroy(sid string) error {
+ return nil
+}
+
+// SessionGC Implement method, no used.
+func (pder *CookieProvider) SessionGC() {
+}
+
+// SessionAll Implement method, return 0.
+func (pder *CookieProvider) SessionAll() int {
+ return 0
+}
+
+// SessionUpdate Implement method, no used.
+func (pder *CookieProvider) SessionUpdate(sid string) error {
+ return nil
+}
+
+func init() {
+ Register("cookie", cookiepder)
+}
diff --git a/pkg/session/sess_cookie_test.go b/pkg/session/sess_cookie_test.go
new file mode 100644
index 00000000..b6726005
--- /dev/null
+++ b/pkg/session/sess_cookie_test.go
@@ -0,0 +1,105 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package session
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestCookie(t *testing.T) {
+ config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
+ conf := new(ManagerConfig)
+ if err := json.Unmarshal([]byte(config), conf); err != nil {
+ t.Fatal("json decode error", err)
+ }
+ globalSessions, err := NewManager("cookie", conf)
+ if err != nil {
+ t.Fatal("init cookie session err", err)
+ }
+ r, _ := http.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+ sess, err := globalSessions.SessionStart(w, r)
+ if err != nil {
+ t.Fatal("set error,", err)
+ }
+ err = sess.Set("username", "astaxie")
+ if err != nil {
+ t.Fatal("set error,", err)
+ }
+ if username := sess.Get("username"); username != "astaxie" {
+ t.Fatal("get username error")
+ }
+ sess.SessionRelease(w)
+ if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
+ t.Fatal("setcookie error")
+ } else {
+ parts := strings.Split(strings.TrimSpace(cookiestr), ";")
+ for k, v := range parts {
+ nameval := strings.Split(v, "=")
+ if k == 0 && nameval[0] != "gosessionid" {
+ t.Fatal("error")
+ }
+ }
+ }
+}
+
+func TestDestorySessionCookie(t *testing.T) {
+ config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
+ conf := new(ManagerConfig)
+ if err := json.Unmarshal([]byte(config), conf); err != nil {
+ t.Fatal("json decode error", err)
+ }
+ globalSessions, err := NewManager("cookie", conf)
+ if err != nil {
+ t.Fatal("init cookie session err", err)
+ }
+
+ r, _ := http.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+ session, err := globalSessions.SessionStart(w, r)
+ if err != nil {
+ t.Fatal("session start err,", err)
+ }
+
+ // request again ,will get same sesssion id .
+ r1, _ := http.NewRequest("GET", "/", nil)
+ r1.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
+ w = httptest.NewRecorder()
+ newSession, err := globalSessions.SessionStart(w, r1)
+ if err != nil {
+ t.Fatal("session start err,", err)
+ }
+ if newSession.SessionID() != session.SessionID() {
+ t.Fatal("get cookie session id is not the same again.")
+ }
+
+ // After destroy session , will get a new session id .
+ globalSessions.SessionDestroy(w, r1)
+ r2, _ := http.NewRequest("GET", "/", nil)
+ r2.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
+
+ w = httptest.NewRecorder()
+ newSession, err = globalSessions.SessionStart(w, r2)
+ if err != nil {
+ t.Fatal("session start error")
+ }
+ if newSession.SessionID() == session.SessionID() {
+ t.Fatal("after destroy session and reqeust again ,get cookie session id is same.")
+ }
+}
diff --git a/pkg/session/sess_file.go b/pkg/session/sess_file.go
new file mode 100644
index 00000000..47ad54a7
--- /dev/null
+++ b/pkg/session/sess_file.go
@@ -0,0 +1,315 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package session
+
+import (
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+)
+
+var (
+ filepder = &FileProvider{}
+ gcmaxlifetime int64
+)
+
+// FileSessionStore File session store
+type FileSessionStore struct {
+ sid string
+ lock sync.RWMutex
+ values map[interface{}]interface{}
+}
+
+// Set value to file session
+func (fs *FileSessionStore) Set(key, value interface{}) error {
+ fs.lock.Lock()
+ defer fs.lock.Unlock()
+ fs.values[key] = value
+ return nil
+}
+
+// Get value from file session
+func (fs *FileSessionStore) Get(key interface{}) interface{} {
+ fs.lock.RLock()
+ defer fs.lock.RUnlock()
+ if v, ok := fs.values[key]; ok {
+ return v
+ }
+ return nil
+}
+
+// Delete value in file session by given key
+func (fs *FileSessionStore) Delete(key interface{}) error {
+ fs.lock.Lock()
+ defer fs.lock.Unlock()
+ delete(fs.values, key)
+ return nil
+}
+
+// Flush Clean all values in file session
+func (fs *FileSessionStore) Flush() error {
+ fs.lock.Lock()
+ defer fs.lock.Unlock()
+ fs.values = make(map[interface{}]interface{})
+ return nil
+}
+
+// SessionID Get file session store id
+func (fs *FileSessionStore) SessionID() string {
+ return fs.sid
+}
+
+// SessionRelease gob-encodes the session values and writes them to the
+// session file derived from sid (savePath/s[0]/s[1]/sid). It serializes all
+// file access through the package-level provider's lock. Truncate/Seek/Write
+// errors are ignored — the session.Store interface has no error channel here.
+func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
+	filepder.lock.Lock()
+	defer filepder.lock.Unlock()
+	b, err := EncodeGob(fs.values)
+	if err != nil {
+		SLogger.Println(err)
+		return
+	}
+	// Open the existing file, or create it when the stat says it is missing.
+	_, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
+	var f *os.File
+	if err == nil {
+		f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777)
+		if err != nil {
+			SLogger.Println(err)
+			return
+		}
+	} else if os.IsNotExist(err) {
+		f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
+		if err != nil {
+			SLogger.Println(err)
+			return
+		}
+	} else {
+		// Unexpected stat error: give up silently.
+		return
+	}
+	// Overwrite the previous payload in place.
+	f.Truncate(0)
+	f.Seek(0, 0)
+	f.Write(b)
+	f.Close()
+}
+
+// FileProvider File session provider
+type FileProvider struct {
+ lock sync.RWMutex
+ maxlifetime int64
+ savePath string
+}
+
+// SessionInit Init file session provider.
+// savePath sets the session files path.
+func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
+ fp.maxlifetime = maxlifetime
+ fp.savePath = savePath
+ return nil
+}
+
+// SessionRead Read file session by sid.
+// if file is not exist, create it.
+// the file path is generated from sid string
+// (savePath/sid[0]/sid[1]/sid) to spread files across directories.
+func (fp *FileProvider) SessionRead(sid string) (Store, error) {
+	// Reject sids that could escape the save path or are too short to shard.
+	invalidChars := "./"
+	if strings.ContainsAny(sid, invalidChars) {
+		return nil, errors.New("the sid shouldn't have following characters: " + invalidChars)
+	}
+	if len(sid) < 2 {
+		return nil, errors.New("length of the sid is less than 2")
+	}
+	filepder.lock.Lock()
+	defer filepder.lock.Unlock()
+
+	err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755)
+	if err != nil {
+		SLogger.Println(err.Error())
+	}
+	_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
+	var f *os.File
+	if err == nil {
+		f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777)
+	} else if os.IsNotExist(err) {
+		f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
+	} else {
+		return nil, err
+	}
+	// BUG FIX: the OpenFile/Create error was previously never checked, so a
+	// failed open (e.g. permission denied) left f == nil and the deferred
+	// f.Close() below panicked with a nil pointer dereference.
+	if err != nil {
+		return nil, err
+	}
+
+	defer f.Close()
+
+	// Touch the file so GC sees the session as recently used.
+	os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
+	var kv map[interface{}]interface{}
+	b, err := ioutil.ReadAll(f)
+	if err != nil {
+		return nil, err
+	}
+	if len(b) == 0 {
+		// Fresh/empty session file: start with an empty map.
+		kv = make(map[interface{}]interface{})
+	} else {
+		kv, err = DecodeGob(b)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	ss := &FileSessionStore{sid: sid, values: kv}
+	return ss, nil
+}
+
+// SessionExist reports whether a session file named after sid exists under
+// the provider's save path (savePath/sid[0]/sid[1]/sid).
+func (fp *FileProvider) SessionExist(sid string) bool {
+	filepder.lock.Lock()
+	defer filepder.lock.Unlock()
+
+	if len(sid) < 2 {
+		SLogger.Println("min length of session id is 2", sid)
+		return false
+	}
+
+	sessFile := path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)
+	_, statErr := os.Stat(sessFile)
+	return statErr == nil
+}
+
+// SessionDestroy deletes the session file belonging to sid.
+// The remove error is ignored; a missing file means nothing to destroy.
+func (fp *FileProvider) SessionDestroy(sid string) error {
+	filepder.lock.Lock()
+	defer filepder.lock.Unlock()
+
+	sessFile := path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)
+	os.Remove(sessFile)
+	return nil
+}
+
+// SessionGC walks the save path and removes session files older than
+// maxlifetime (see gcpath). NOTE(review): the lifetime is passed to the walk
+// callback via the package-level gcmaxlifetime variable — racy if several
+// FileProviders ran GC concurrently; confirm only the registered singleton
+// is used.
+func (fp *FileProvider) SessionGC() {
+	filepder.lock.Lock()
+	defer filepder.lock.Unlock()
+
+	gcmaxlifetime = fp.maxlifetime
+	filepath.Walk(fp.savePath, gcpath)
+}
+
+// SessionAll Get active file session number.
+// it walks save path to count files.
+func (fp *FileProvider) SessionAll() int {
+ a := &activeSession{}
+ err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error {
+ return a.visit(path, f, err)
+ })
+ if err != nil {
+ SLogger.Printf("filepath.Walk() returned %v\n", err)
+ return 0
+ }
+ return a.total
+}
+
+// SessionRegenerate Generate new sid for file session.
+// it delete old file and create new file named from new sid.
+// Fails when a file for the new sid already exists.
+func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
+	filepder.lock.Lock()
+	defer filepder.lock.Unlock()
+
+	oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]))
+	oldSidFile := path.Join(oldPath, oldsid)
+	newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1]))
+	newSidFile := path.Join(newPath, sid)
+
+	// new sid file is exist
+	_, err := os.Stat(newSidFile)
+	if err == nil {
+		return nil, fmt.Errorf("newsid %s exist", newSidFile)
+	}
+
+	err = os.MkdirAll(newPath, 0755)
+	if err != nil {
+		SLogger.Println(err.Error())
+	}
+
+	// if old sid file exist
+	// 1.read and parse file content
+	// 2.write content to new sid file
+	// 3.remove old sid file, change new sid file atime and ctime
+	// 4.return FileSessionStore
+	_, err = os.Stat(oldSidFile)
+	if err == nil {
+		b, err := ioutil.ReadFile(oldSidFile)
+		if err != nil {
+			return nil, err
+		}
+
+		var kv map[interface{}]interface{}
+		if len(b) == 0 {
+			kv = make(map[interface{}]interface{})
+		} else {
+			// Validate the payload decodes before committing the rename.
+			kv, err = DecodeGob(b)
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		// Best-effort copy/remove/touch: errors here are intentionally ignored.
+		ioutil.WriteFile(newSidFile, b, 0777)
+		os.Remove(oldSidFile)
+		os.Chtimes(newSidFile, time.Now(), time.Now())
+		ss := &FileSessionStore{sid: sid, values: kv}
+		return ss, nil
+	}
+
+	// if old sid file not exist, just create new sid file and return
+	newf, err := os.Create(newSidFile)
+	if err != nil {
+		return nil, err
+	}
+	newf.Close()
+	ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})}
+	return ss, nil
+}
+
+// gcpath is the filepath.Walk callback used by SessionGC: it removes a
+// session file whose modification time is older than gcmaxlifetime seconds.
+// Directories are left untouched; walk errors are propagated.
+func gcpath(path string, info os.FileInfo, err error) error {
+	switch {
+	case err != nil:
+		return err
+	case info.IsDir():
+		return nil
+	}
+	expiresAt := info.ModTime().Unix() + gcmaxlifetime
+	if expiresAt < time.Now().Unix() {
+		os.Remove(path)
+	}
+	return nil
+}
+
+type activeSession struct {
+ total int
+}
+
+// visit counts regular files seen during a filepath.Walk; each file is one
+// active session. Walk errors are propagated unchanged.
+func (as *activeSession) visit(paths string, f os.FileInfo, err error) error {
+	if err != nil {
+		return err
+	}
+	if !f.IsDir() {
+		as.total++
+	}
+	return nil
+}
+
+func init() {
+ Register("file", filepder)
+}
diff --git a/pkg/session/sess_file_test.go b/pkg/session/sess_file_test.go
new file mode 100644
index 00000000..0cf021db
--- /dev/null
+++ b/pkg/session/sess_file_test.go
@@ -0,0 +1,387 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package session
+
+import (
+ "fmt"
+ "os"
+ "sync"
+ "testing"
+ "time"
+)
+
// Fixture values shared by every file-provider test in this file.
const sid = "Session_id"                 // base session id used by most tests
const sidNew = "Session_id_new"          // target id for the regenerate test
const sessionPath = "./_session_runtime" // scratch directory for session files

var (
	// mutex serializes the tests: they all create and delete sessionPath.
	mutex sync.Mutex
)
+
+func TestFileProvider_SessionInit(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+ if fp.maxlifetime != 180 {
+ t.Error()
+ }
+
+ if fp.savePath != sessionPath {
+ t.Error()
+ }
+}
+
+func TestFileProvider_SessionExist(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ if fp.SessionExist(sid) {
+ t.Error()
+ }
+
+ _, err := fp.SessionRead(sid)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !fp.SessionExist(sid) {
+ t.Error()
+ }
+}
+
+func TestFileProvider_SessionExist2(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ if fp.SessionExist(sid) {
+ t.Error()
+ }
+
+ if fp.SessionExist("") {
+ t.Error()
+ }
+
+ if fp.SessionExist("1") {
+ t.Error()
+ }
+}
+
+func TestFileProvider_SessionRead(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ s, err := fp.SessionRead(sid)
+ if err != nil {
+ t.Error(err)
+ }
+
+ _ = s.Set("sessionValue", 18975)
+ v := s.Get("sessionValue")
+
+ if v.(int) != 18975 {
+ t.Error()
+ }
+}
+
+func TestFileProvider_SessionRead1(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ _, err := fp.SessionRead("")
+ if err == nil {
+ t.Error(err)
+ }
+
+ _, err = fp.SessionRead("1")
+ if err == nil {
+ t.Error(err)
+ }
+}
+
+func TestFileProvider_SessionAll(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ sessionCount := 546
+
+ for i := 1; i <= sessionCount; i++ {
+ _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i))
+ if err != nil {
+ t.Error(err)
+ }
+ }
+
+ if fp.SessionAll() != sessionCount {
+ t.Error()
+ }
+}
+
+func TestFileProvider_SessionRegenerate(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ _, err := fp.SessionRead(sid)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !fp.SessionExist(sid) {
+ t.Error()
+ }
+
+ _, err = fp.SessionRegenerate(sid, sidNew)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if fp.SessionExist(sid) {
+ t.Error()
+ }
+
+ if !fp.SessionExist(sidNew) {
+ t.Error()
+ }
+}
+
+func TestFileProvider_SessionDestroy(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ _, err := fp.SessionRead(sid)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !fp.SessionExist(sid) {
+ t.Error()
+ }
+
+ err = fp.SessionDestroy(sid)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if fp.SessionExist(sid) {
+ t.Error()
+ }
+}
+
+func TestFileProvider_SessionGC(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(1, sessionPath)
+
+ sessionCount := 412
+
+ for i := 1; i <= sessionCount; i++ {
+ _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i))
+ if err != nil {
+ t.Error(err)
+ }
+ }
+
+ time.Sleep(2 * time.Second)
+
+ fp.SessionGC()
+ if fp.SessionAll() != 0 {
+ t.Error()
+ }
+}
+
+func TestFileSessionStore_Set(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ sessionCount := 100
+ s, _ := fp.SessionRead(sid)
+ for i := 1; i <= sessionCount; i++ {
+ err := s.Set(i, i)
+ if err != nil {
+ t.Error(err)
+ }
+ }
+}
+
+func TestFileSessionStore_Get(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ sessionCount := 100
+ s, _ := fp.SessionRead(sid)
+ for i := 1; i <= sessionCount; i++ {
+ _ = s.Set(i, i)
+
+ v := s.Get(i)
+ if v.(int) != i {
+ t.Error()
+ }
+ }
+}
+
+func TestFileSessionStore_Delete(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ s, _ := fp.SessionRead(sid)
+ s.Set("1", 1)
+
+ if s.Get("1") == nil {
+ t.Error()
+ }
+
+ s.Delete("1")
+
+ if s.Get("1") != nil {
+ t.Error()
+ }
+}
+
+func TestFileSessionStore_Flush(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ sessionCount := 100
+ s, _ := fp.SessionRead(sid)
+ for i := 1; i <= sessionCount; i++ {
+ _ = s.Set(i, i)
+ }
+
+ _ = s.Flush()
+
+ for i := 1; i <= sessionCount; i++ {
+ if s.Get(i) != nil {
+ t.Error()
+ }
+ }
+}
+
+func TestFileSessionStore_SessionID(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+
+ sessionCount := 85
+
+ for i := 1; i <= sessionCount; i++ {
+ s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i))
+ if err != nil {
+ t.Error(err)
+ }
+ if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) {
+ t.Error(err)
+ }
+ }
+}
+
+func TestFileSessionStore_SessionRelease(t *testing.T) {
+ mutex.Lock()
+ defer mutex.Unlock()
+ os.RemoveAll(sessionPath)
+ defer os.RemoveAll(sessionPath)
+ fp := &FileProvider{}
+
+ _ = fp.SessionInit(180, sessionPath)
+ filepder.savePath = sessionPath
+ sessionCount := 85
+
+ for i := 1; i <= sessionCount; i++ {
+ s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i))
+ if err != nil {
+ t.Error(err)
+ }
+
+
+ s.Set(i,i)
+ s.SessionRelease(nil)
+ }
+
+ for i := 1; i <= sessionCount; i++ {
+ s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i))
+ if err != nil {
+ t.Error(err)
+ }
+
+ if s.Get(i).(int) != i {
+ t.Error()
+ }
+ }
+}
\ No newline at end of file
diff --git a/pkg/session/sess_mem.go b/pkg/session/sess_mem.go
new file mode 100644
index 00000000..64d8b056
--- /dev/null
+++ b/pkg/session/sess_mem.go
@@ -0,0 +1,196 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package session
+
+import (
+ "container/list"
+ "net/http"
+ "sync"
+ "time"
+)
+
+var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
+
// MemSessionStore holds one session's key/value data entirely in memory,
// guarded by an RWMutex for concurrent access.
type MemSessionStore struct {
	sid          string                      // session id
	timeAccessed time.Time                   // last access time
	value        map[interface{}]interface{} // session data
	lock         sync.RWMutex
}

// Set stores value under key.
func (st *MemSessionStore) Set(key, value interface{}) error {
	st.lock.Lock()
	st.value[key] = value
	st.lock.Unlock()
	return nil
}

// Get returns the value stored under key, or nil when the key is absent.
func (st *MemSessionStore) Get(key interface{}) interface{} {
	st.lock.RLock()
	defer st.lock.RUnlock()
	return st.value[key]
}

// Delete removes key (and its value) from the session.
func (st *MemSessionStore) Delete(key interface{}) error {
	st.lock.Lock()
	delete(st.value, key)
	st.lock.Unlock()
	return nil
}

// Flush discards every value held by the session.
func (st *MemSessionStore) Flush() error {
	st.lock.Lock()
	st.value = map[interface{}]interface{}{}
	st.lock.Unlock()
	return nil
}

// SessionID returns this session's id.
func (st *MemSessionStore) SessionID() string {
	return st.sid
}

// SessionRelease implements the Store interface; in-memory sessions have no
// persistence step, so this is a no-op.
func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
}
+
// MemProvider implements the Provider interface with an all-in-memory
// backend. sessions maps a session id to its element in list; list keeps the
// stores ordered by last access (front = most recent) so GC can scan expired
// entries from the back.
type MemProvider struct {
	lock        sync.RWMutex             // guards sessions and list
	sessions    map[string]*list.Element // sid -> element in list
	list        *list.List               // access-ordered stores, for gc
	maxlifetime int64
	savePath    string
}
+
+// SessionInit init memory session
+func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
+ pder.maxlifetime = maxlifetime
+ pder.savePath = savePath
+ return nil
+}
+
+// SessionRead get memory session store by sid
+func (pder *MemProvider) SessionRead(sid string) (Store, error) {
+ pder.lock.RLock()
+ if element, ok := pder.sessions[sid]; ok {
+ go pder.SessionUpdate(sid)
+ pder.lock.RUnlock()
+ return element.Value.(*MemSessionStore), nil
+ }
+ pder.lock.RUnlock()
+ pder.lock.Lock()
+ newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
+ element := pder.list.PushFront(newsess)
+ pder.sessions[sid] = element
+ pder.lock.Unlock()
+ return newsess, nil
+}
+
+// SessionExist check session store exist in memory session by sid
+func (pder *MemProvider) SessionExist(sid string) bool {
+ pder.lock.RLock()
+ defer pder.lock.RUnlock()
+ if _, ok := pder.sessions[sid]; ok {
+ return true
+ }
+ return false
+}
+
+// SessionRegenerate generate new sid for session store in memory session
+func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
+ pder.lock.RLock()
+ if element, ok := pder.sessions[oldsid]; ok {
+ go pder.SessionUpdate(oldsid)
+ pder.lock.RUnlock()
+ pder.lock.Lock()
+ element.Value.(*MemSessionStore).sid = sid
+ pder.sessions[sid] = element
+ delete(pder.sessions, oldsid)
+ pder.lock.Unlock()
+ return element.Value.(*MemSessionStore), nil
+ }
+ pder.lock.RUnlock()
+ pder.lock.Lock()
+ newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
+ element := pder.list.PushFront(newsess)
+ pder.sessions[sid] = element
+ pder.lock.Unlock()
+ return newsess, nil
+}
+
+// SessionDestroy delete session store in memory session by id
+func (pder *MemProvider) SessionDestroy(sid string) error {
+ pder.lock.Lock()
+ defer pder.lock.Unlock()
+ if element, ok := pder.sessions[sid]; ok {
+ delete(pder.sessions, sid)
+ pder.list.Remove(element)
+ return nil
+ }
+ return nil
+}
+
// SessionGC walks the access-ordered list from the back (least recently
// used) and evicts every session whose last access is older than
// maxlifetime. It stops at the first non-expired entry, relying on the list
// being ordered by access time.
func (pder *MemProvider) SessionGC() {
	pder.lock.RLock()
	for {
		element := pder.list.Back()
		if element == nil {
			break
		}
		if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() {
			// Upgrade to the write lock for the removal, then downgrade
			// before re-inspecting the list tail.
			// NOTE(review): the lock is released between the expiry check
			// and the removal, so another goroutine may touch or re-read
			// this session in between — confirm that race is acceptable
			// for best-effort GC.
			pder.lock.RUnlock()
			pder.lock.Lock()
			pder.list.Remove(element)
			delete(pder.sessions, element.Value.(*MemSessionStore).sid)
			pder.lock.Unlock()
			pder.lock.RLock()
		} else {
			break
		}
	}
	pder.lock.RUnlock()
}
+
// SessionAll returns the number of sessions currently held in memory.
// NOTE(review): reads list.Len() without taking pder.lock, so the count is
// racy under concurrent use — confirm callers only need an approximation.
func (pder *MemProvider) SessionAll() int {
	return pder.list.Len()
}
+
+// SessionUpdate expand time of session store by id in memory session
+func (pder *MemProvider) SessionUpdate(sid string) error {
+ pder.lock.Lock()
+ defer pder.lock.Unlock()
+ if element, ok := pder.sessions[sid]; ok {
+ element.Value.(*MemSessionStore).timeAccessed = time.Now()
+ pder.list.MoveToFront(element)
+ return nil
+ }
+ return nil
+}
+
// init registers the in-memory provider under the name "memory".
func init() {
	Register("memory", mempder)
}
diff --git a/pkg/session/sess_mem_test.go b/pkg/session/sess_mem_test.go
new file mode 100644
index 00000000..2e8934b8
--- /dev/null
+++ b/pkg/session/sess_mem_test.go
@@ -0,0 +1,58 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package session
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestMem(t *testing.T) {
+ config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}`
+ conf := new(ManagerConfig)
+ if err := json.Unmarshal([]byte(config), conf); err != nil {
+ t.Fatal("json decode error", err)
+ }
+ globalSessions, _ := NewManager("memory", conf)
+ go globalSessions.GC()
+ r, _ := http.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+ sess, err := globalSessions.SessionStart(w, r)
+ if err != nil {
+ t.Fatal("set error,", err)
+ }
+ defer sess.SessionRelease(w)
+ err = sess.Set("username", "astaxie")
+ if err != nil {
+ t.Fatal("set error,", err)
+ }
+ if username := sess.Get("username"); username != "astaxie" {
+ t.Fatal("get username error")
+ }
+ if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
+ t.Fatal("setcookie error")
+ } else {
+ parts := strings.Split(strings.TrimSpace(cookiestr), ";")
+ for k, v := range parts {
+ nameval := strings.Split(v, "=")
+ if k == 0 && nameval[0] != "gosessionid" {
+ t.Fatal("error")
+ }
+ }
+ }
+}
diff --git a/pkg/session/sess_test.go b/pkg/session/sess_test.go
new file mode 100644
index 00000000..906abec2
--- /dev/null
+++ b/pkg/session/sess_test.go
@@ -0,0 +1,131 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package session
+
+import (
+ "crypto/aes"
+ "encoding/json"
+ "testing"
+)
+
+func Test_gob(t *testing.T) {
+ a := make(map[interface{}]interface{})
+ a["username"] = "astaxie"
+ a[12] = 234
+ a["user"] = User{"asta", "xie"}
+ b, err := EncodeGob(a)
+ if err != nil {
+ t.Error(err)
+ }
+ c, err := DecodeGob(b)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(c) == 0 {
+ t.Error("decodeGob empty")
+ }
+ if c["username"] != "astaxie" {
+ t.Error("decode string error")
+ }
+ if c[12] != 234 {
+ t.Error("decode int error")
+ }
+ if c["user"].(User).Username != "asta" {
+ t.Error("decode struct error")
+ }
+}
+
+type User struct {
+ Username string
+ NickName string
+}
+
+func TestGenerate(t *testing.T) {
+ str := generateRandomKey(20)
+ if len(str) != 20 {
+ t.Fatal("generate length is not equal to 20")
+ }
+}
+
+func TestCookieEncodeDecode(t *testing.T) {
+ hashKey := "testhashKey"
+ blockkey := generateRandomKey(16)
+ block, err := aes.NewCipher(blockkey)
+ if err != nil {
+ t.Fatal("NewCipher:", err)
+ }
+ securityName := string(generateRandomKey(20))
+ val := make(map[interface{}]interface{})
+ val["name"] = "astaxie"
+ val["gender"] = "male"
+ str, err := encodeCookie(block, hashKey, securityName, val)
+ if err != nil {
+ t.Fatal("encodeCookie:", err)
+ }
+ dst, err := decodeCookie(block, hashKey, securityName, str, 3600)
+ if err != nil {
+ t.Fatal("decodeCookie", err)
+ }
+ if dst["name"] != "astaxie" {
+ t.Fatal("dst get map error")
+ }
+ if dst["gender"] != "male" {
+ t.Fatal("dst get map error")
+ }
+}
+
+func TestParseConfig(t *testing.T) {
+ s := `{"cookieName":"gosessionid","gclifetime":3600}`
+ cf := new(ManagerConfig)
+ cf.EnableSetCookie = true
+ err := json.Unmarshal([]byte(s), cf)
+ if err != nil {
+ t.Fatal("parse json error,", err)
+ }
+ if cf.CookieName != "gosessionid" {
+ t.Fatal("parseconfig get cookiename error")
+ }
+ if cf.Gclifetime != 3600 {
+ t.Fatal("parseconfig get gclifetime error")
+ }
+
+ cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
+ cf2 := new(ManagerConfig)
+ cf2.EnableSetCookie = true
+ err = json.Unmarshal([]byte(cc), cf2)
+ if err != nil {
+ t.Fatal("parse json error,", err)
+ }
+ if cf2.CookieName != "gosessionid" {
+ t.Fatal("parseconfig get cookiename error")
+ }
+ if cf2.Gclifetime != 3600 {
+ t.Fatal("parseconfig get gclifetime error")
+ }
+ if cf2.EnableSetCookie {
+ t.Fatal("parseconfig get enableSetCookie error")
+ }
+ cconfig := new(cookieConfig)
+ err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
+ if err != nil {
+ t.Fatal("parse ProviderConfig err,", err)
+ }
+ if cconfig.CookieName != "gosessionid" {
+ t.Fatal("ProviderConfig get cookieName error")
+ }
+ if cconfig.SecurityKey != "beegocookiehashkey" {
+ t.Fatal("ProviderConfig get securityKey error")
+ }
+}
diff --git a/pkg/session/sess_utils.go b/pkg/session/sess_utils.go
new file mode 100644
index 00000000..20915bb6
--- /dev/null
+++ b/pkg/session/sess_utils.go
@@ -0,0 +1,207 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package session
+
+import (
+ "bytes"
+ "crypto/cipher"
+ "crypto/hmac"
+ "crypto/rand"
+ "crypto/sha256"
+ "crypto/subtle"
+ "encoding/base64"
+ "encoding/gob"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "time"
+
+ "github.com/astaxie/beego/utils"
+)
+
// init pre-registers the container types commonly stored in sessions so gob
// can encode them when they appear behind an interface{} map value.
func init() {
	gob.Register([]interface{}{})
	gob.Register(map[int]interface{}{})
	gob.Register(map[string]interface{}{})
	gob.Register(map[interface{}]interface{}{})
	gob.Register(map[string]string{})
	gob.Register(map[int]string{})
	gob.Register(map[int]int{})
	gob.Register(map[int]int64{})
}
+
+// EncodeGob encode the obj to gob
+func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) {
+ for _, v := range obj {
+ gob.Register(v)
+ }
+ buf := bytes.NewBuffer(nil)
+ enc := gob.NewEncoder(buf)
+ err := enc.Encode(obj)
+ if err != nil {
+ return []byte(""), err
+ }
+ return buf.Bytes(), nil
+}
+
+// DecodeGob decode data to map
+func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) {
+ buf := bytes.NewBuffer(encoded)
+ dec := gob.NewDecoder(buf)
+ var out map[interface{}]interface{}
+ err := dec.Decode(&out)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
// generateRandomKey creates a random key of the given strength (in bytes).
// It prefers the system CSPRNG (crypto/rand) and falls back to
// utils.RandomCreateBytes only when the system source cannot supply the full
// amount.
func generateRandomKey(strength int) []byte {
	k := make([]byte, strength)
	if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil {
		return utils.RandomCreateBytes(strength)
	}
	return k
}
+
+// Encryption -----------------------------------------------------------------
+
+// encrypt encrypts a value using the given block in counter mode.
+//
+// A random initialization vector (http://goo.gl/zF67k) with the length of the
+// block size is prepended to the resulting ciphertext.
+func encrypt(block cipher.Block, value []byte) ([]byte, error) {
+ iv := generateRandomKey(block.BlockSize())
+ if iv == nil {
+ return nil, errors.New("encrypt: failed to generate random iv")
+ }
+ // Encrypt it.
+ stream := cipher.NewCTR(block, iv)
+ stream.XORKeyStream(value, value)
+ // Return iv + ciphertext.
+ return append(iv, value...), nil
+}
+
+// decrypt decrypts a value using the given block in counter mode.
+//
+// The value to be decrypted must be prepended by a initialization vector
+// (http://goo.gl/zF67k) with the length of the block size.
+func decrypt(block cipher.Block, value []byte) ([]byte, error) {
+ size := block.BlockSize()
+ if len(value) > size {
+ // Extract iv.
+ iv := value[:size]
+ // Extract ciphertext.
+ value = value[size:]
+ // Decrypt it.
+ stream := cipher.NewCTR(block, iv)
+ stream.XORKeyStream(value, value)
+ return value, nil
+ }
+ return nil, errors.New("decrypt: the value could not be decrypted")
+}
+
// encodeCookie serializes value into a signed, encrypted cookie string.
// Pipeline: gob-encode -> CTR-encrypt -> base64 -> HMAC-SHA256 over
// "name|date|value|" -> strip the name prefix -> base64 the resulting
// "date|value|mac" payload. decodeCookie reverses every step.
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
	var err error
	var b []byte
	// 1. EncodeGob.
	if b, err = EncodeGob(value); err != nil {
		return "", err
	}
	// 2. Encrypt (optional).
	if b, err = encrypt(block, b); err != nil {
		return "", err
	}
	b = encode(b)
	// 3. Create MAC for "name|date|value". Extra pipe to be used later.
	b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
	h := hmac.New(sha256.New, []byte(hashKey))
	h.Write(b)
	sig := h.Sum(nil)
	// Append mac, remove name: the cookie name travels separately, but the
	// MAC still covers it so the value cannot be replayed under another name.
	b = append(b, sig...)[len(name)+1:]
	// 4. Encode to base64.
	b = encode(b)
	// Done.
	return string(b), nil
}
+
// decodeCookie reverses encodeCookie: it base64-decodes the payload,
// re-computes the HMAC over name + "date|value|" and compares it in constant
// time, rejects timestamps that are in the future or older than
// gcmaxlifetime seconds, then decrypts and gob-decodes the value.
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
	// 1. Decode from base64.
	b, err := decode([]byte(value))
	if err != nil {
		return nil, err
	}
	// 2. Verify MAC. Value is "date|value|mac".
	parts := bytes.SplitN(b, []byte("|"), 3)
	if len(parts) != 3 {
		return nil, errors.New("Decode: invalid value format")
	}

	// Re-attach the name prefix that encodeCookie stripped, and drop the
	// trailing MAC, so the HMAC input matches exactly what was signed.
	b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
	h := hmac.New(sha256.New, []byte(hashKey))
	h.Write(b)
	sig := h.Sum(nil)
	if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
		return nil, errors.New("Decode: the value is not valid")
	}
	// 3. Verify date ranges.
	var t1 int64
	if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
		return nil, errors.New("Decode: invalid timestamp")
	}
	t2 := time.Now().UTC().Unix()
	if t1 > t2 {
		return nil, errors.New("Decode: timestamp is too new")
	}
	if t1 < t2-gcmaxlifetime {
		return nil, errors.New("Decode: expired timestamp")
	}
	// 4. Decrypt (optional).
	b, err = decode(parts[1])
	if err != nil {
		return nil, err
	}
	if b, err = decrypt(block, b); err != nil {
		return nil, err
	}
	// 5. DecodeGob.
	dst, err := DecodeGob(b)
	if err != nil {
		return nil, err
	}
	return dst, nil
}
+
+// Encoding -------------------------------------------------------------------
+
+// encode encodes a value using base64.
+func encode(value []byte) []byte {
+ encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
+ base64.URLEncoding.Encode(encoded, value)
+ return encoded
+}
+
+// decode decodes a cookie using base64.
+func decode(value []byte) ([]byte, error) {
+ decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
+ b, err := base64.URLEncoding.Decode(decoded, value)
+ if err != nil {
+ return nil, err
+ }
+ return decoded[:b], nil
+}
diff --git a/pkg/session/session.go b/pkg/session/session.go
new file mode 100644
index 00000000..eb85360a
--- /dev/null
+++ b/pkg/session/session.go
@@ -0,0 +1,377 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package session provider
+//
+// Usage:
+// import(
+// "github.com/astaxie/beego/session"
+// )
+//
+// func init() {
+// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`)
+// go globalSessions.GC()
+// }
+//
+// more docs: http://beego.me/docs/module/session.md
+package session
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/textproto"
+ "net/url"
+ "os"
+ "time"
+)
+
+// Store contains all data for one session process with specific id.
+type Store interface {
+ Set(key, value interface{}) error //set session value
+ Get(key interface{}) interface{} //get session value
+ Delete(key interface{}) error //delete session value
+ SessionID() string //back current sessionID
+ SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
+ Flush() error //delete all data
+}
+
+// Provider contains global session methods and saved SessionStores.
+// it can operate a SessionStore by its id.
+type Provider interface {
+ SessionInit(gclifetime int64, config string) error
+ SessionRead(sid string) (Store, error)
+ SessionExist(sid string) bool
+ SessionRegenerate(oldsid, sid string) (Store, error)
+ SessionDestroy(sid string) error
+ SessionAll() int //get all active session
+ SessionGC()
+}
+
+var provides = make(map[string]Provider)
+
+// SLogger a helpful variable to log information about session
+var SLogger = NewSessionLog(os.Stderr)
+
+// Register makes a session provide available by the provided name.
+// If Register is called twice with the same name or if driver is nil,
+// it panics.
+func Register(name string, provide Provider) {
+ if provide == nil {
+ panic("session: Register provide is nil")
+ }
+ if _, dup := provides[name]; dup {
+ panic("session: Register called twice for provider " + name)
+ }
+ provides[name] = provide
+}
+
// GetProvider returns the provider registered under name, or an error when
// no such provider has been registered (usually a missing driver import).
func GetProvider(name string) (Provider, error) {
	provider, ok := provides[name]
	if !ok {
		return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name)
	}
	return provider, nil
}
+
+// ManagerConfig define the session config
+type ManagerConfig struct {
+ CookieName string `json:"cookieName"`
+ EnableSetCookie bool `json:"enableSetCookie,omitempty"`
+ Gclifetime int64 `json:"gclifetime"`
+ Maxlifetime int64 `json:"maxLifetime"`
+ DisableHTTPOnly bool `json:"disableHTTPOnly"`
+ Secure bool `json:"secure"`
+ CookieLifeTime int `json:"cookieLifeTime"`
+ ProviderConfig string `json:"providerConfig"`
+ Domain string `json:"domain"`
+ SessionIDLength int64 `json:"sessionIDLength"`
+ EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
+ SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
+ EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
+ SessionIDPrefix string `json:"sessionIDPrefix"`
+}
+
+// Manager contains Provider and its configuration.
+type Manager struct {
+ provider Provider
+ config *ManagerConfig
+}
+
+// NewManager Create new Manager with provider name and json config string.
+// provider name:
+// 1. cookie
+// 2. file
+// 3. memory
+// 4. redis
+// 5. mysql
+// json config:
+// 1. is https default false
+// 2. hashfunc default sha1
+// 3. hashkey default beegosessionkey
+// 4. maxage default is none
+func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) {
+ provider, ok := provides[provideName]
+ if !ok {
+ return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
+ }
+
+ if cf.Maxlifetime == 0 {
+ cf.Maxlifetime = cf.Gclifetime
+ }
+
+ if cf.EnableSidInHTTPHeader {
+ if cf.SessionNameInHTTPHeader == "" {
+ panic(errors.New("SessionNameInHTTPHeader is empty"))
+ }
+
+ strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader)
+ if cf.SessionNameInHTTPHeader != strMimeHeader {
+ strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader
+ panic(errors.New(strErrMsg))
+ }
+ }
+
+ err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
+ if err != nil {
+ return nil, err
+ }
+
+ if cf.SessionIDLength == 0 {
+ cf.SessionIDLength = 16
+ }
+
+ return &Manager{
+ provider,
+ cf,
+ }, nil
+}
+
+// GetProvider return current manager's provider
+func (manager *Manager) GetProvider() Provider {
+ return manager.provider
+}
+
+// getSid retrieves session identifier from HTTP Request.
+// First try to retrieve id by reading from cookie, session cookie name is configurable,
+// if not exist, then retrieve id from querying parameters.
+//
+// error is not nil when there is anything wrong.
+// sid is empty when need to generate a new session id
+// otherwise return an valid session id.
+func (manager *Manager) getSid(r *http.Request) (string, error) {
+ cookie, errs := r.Cookie(manager.config.CookieName)
+ if errs != nil || cookie.Value == "" {
+ var sid string
+ if manager.config.EnableSidInURLQuery {
+ errs := r.ParseForm()
+ if errs != nil {
+ return "", errs
+ }
+
+ sid = r.FormValue(manager.config.CookieName)
+ }
+
+ // if not found in Cookie / param, then read it from request headers
+ if manager.config.EnableSidInHTTPHeader && sid == "" {
+ sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader]
+ if isFound && len(sids) != 0 {
+ return sids[0], nil
+ }
+ }
+
+ return sid, nil
+ }
+
+ // HTTP Request contains cookie for sessionid info.
+ return url.QueryUnescape(cookie.Value)
+}
+
+// SessionStart returns the SessionStore for this request. When the request
+// carries a known session id, the existing session is returned; otherwise a
+// new id is generated and written back as a cookie (and, when configured,
+// as a request/response header).
+func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) {
+	sid, errs := manager.getSid(r)
+	if errs != nil {
+		return nil, errs
+	}
+
+	if sid != "" && manager.provider.SessionExist(sid) {
+		return manager.provider.SessionRead(sid)
+	}
+
+	// Generate a new session
+	sid, errs = manager.sessionID()
+	if errs != nil {
+		return nil, errs
+	}
+
+	session, err = manager.provider.SessionRead(sid)
+	if err != nil {
+		return nil, err
+	}
+	cookie := &http.Cookie{
+		Name:     manager.config.CookieName,
+		Value:    url.QueryEscape(sid),
+		Path:     "/",
+		HttpOnly: !manager.config.DisableHTTPOnly,
+		Secure:   manager.isSecure(r),
+		Domain:   manager.config.Domain,
+	}
+	// CookieLifeTime <= 0 leaves a browser-session cookie (no expiry set).
+	if manager.config.CookieLifeTime > 0 {
+		cookie.MaxAge = manager.config.CookieLifeTime
+		cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
+	}
+	if manager.config.EnableSetCookie {
+		http.SetCookie(w, cookie)
+	}
+	// Attach to the request too, so later handlers in this request see it.
+	r.AddCookie(cookie)
+
+	if manager.config.EnableSidInHTTPHeader {
+		r.Header.Set(manager.config.SessionNameInHTTPHeader, sid)
+		w.Header().Set(manager.config.SessionNameInHTTPHeader, sid)
+	}
+
+	return
+}
+
+// SessionDestroy deletes the session matched by the request cookie and
+// expires the client-side cookie (and clears the sid headers when enabled).
+func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
+	if manager.config.EnableSidInHTTPHeader {
+		r.Header.Del(manager.config.SessionNameInHTTPHeader)
+		w.Header().Del(manager.config.SessionNameInHTTPHeader)
+	}
+
+	cookie, err := r.Cookie(manager.config.CookieName)
+	if err != nil || cookie.Value == "" {
+		return
+	}
+
+	sid, _ := url.QueryUnescape(cookie.Value)
+	manager.provider.SessionDestroy(sid)
+	if !manager.config.EnableSetCookie {
+		return
+	}
+	expired := &http.Cookie{
+		Name:     manager.config.CookieName,
+		Path:     "/",
+		HttpOnly: !manager.config.DisableHTTPOnly,
+		Expires:  time.Now(),
+		MaxAge:   -1,
+		Domain:   manager.config.Domain,
+	}
+	http.SetCookie(w, expired)
+}
+
+// GetSessionStore returns the SessionStore whose id is sid, reading it
+// straight from the provider.
+func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) {
+	return manager.provider.SessionRead(sid)
+}
+
+// GC runs one garbage-collection pass on the provider and schedules the
+// next pass to run after Gclifetime seconds.
+func (manager *Manager) GC() {
+	manager.provider.SessionGC()
+	interval := time.Duration(manager.config.Gclifetime) * time.Second
+	time.AfterFunc(interval, manager.GC)
+}
+
+// SessionRegenerateID regenerates the session id for the SessionStore
+// referenced by the request cookie, keeping the stored session data, and
+// writes the new id back via cookie (and headers, when enabled).
+func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) {
+	sid, err := manager.sessionID()
+	if err != nil {
+		return
+	}
+	cookie, err := r.Cookie(manager.config.CookieName)
+	if err != nil || cookie.Value == "" {
+		// No usable cookie: start a fresh session under the new id.
+		session, _ = manager.provider.SessionRead(sid)
+		cookie = &http.Cookie{Name: manager.config.CookieName,
+			Value:    url.QueryEscape(sid),
+			Path:     "/",
+			HttpOnly: !manager.config.DisableHTTPOnly,
+			Secure:   manager.isSecure(r),
+			Domain:   manager.config.Domain,
+		}
+	} else {
+		// Move the stored session from the old id to the new one.
+		oldsid, _ := url.QueryUnescape(cookie.Value)
+		session, _ = manager.provider.SessionRegenerate(oldsid, sid)
+		cookie.Value = url.QueryEscape(sid)
+		// Fix: honor DisableHTTPOnly here too. This branch previously
+		// hard-coded HttpOnly=true, inconsistent with the branch above.
+		cookie.HttpOnly = !manager.config.DisableHTTPOnly
+		cookie.Path = "/"
+	}
+	if manager.config.CookieLifeTime > 0 {
+		cookie.MaxAge = manager.config.CookieLifeTime
+		cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
+	}
+	if manager.config.EnableSetCookie {
+		http.SetCookie(w, cookie)
+	}
+	r.AddCookie(cookie)
+
+	if manager.config.EnableSidInHTTPHeader {
+		r.Header.Set(manager.config.SessionNameInHTTPHeader, sid)
+		w.Header().Set(manager.config.SessionNameInHTTPHeader, sid)
+	}
+
+	return
+}
+
+// GetActiveSession returns the number of active sessions as reported by
+// the configured provider.
+func (manager *Manager) GetActiveSession() int {
+	return manager.provider.SessionAll()
+}
+
+// SetSecure toggles whether session cookies are marked Secure (HTTPS-only).
+func (manager *Manager) SetSecure(secure bool) {
+	manager.config.Secure = secure
+}
+
+// sessionID generates a new random session id of SessionIDLength bytes,
+// hex-encoded and prefixed with SessionIDPrefix.
+func (manager *Manager) sessionID() (string, error) {
+	b := make([]byte, manager.config.SessionIDLength)
+	n, err := rand.Read(b)
+	if n != len(b) || err != nil {
+		// Go convention: error strings are lowercase; include the cause.
+		return "", fmt.Errorf("could not read from the system CSPRNG: %v", err)
+	}
+	return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil
+}
+
+// isSecure reports whether the session cookie for this request should
+// carry the Secure flag: the manager must be configured for secure cookies
+// and the request must have arrived over HTTPS.
+func (manager *Manager) isSecure(req *http.Request) bool {
+	if !manager.config.Secure {
+		return false
+	}
+	if scheme := req.URL.Scheme; scheme != "" {
+		return scheme == "https"
+	}
+	return req.TLS != nil
+}
+
+// Log embeds *log.Logger to act as the session package's logger type.
+type Log struct {
+	*log.Logger
+}
+
+// NewSessionLog creates a session Logger that writes to out with the
+// "[SESSION]" prefix.
+func NewSessionLog(out io.Writer) *Log {
+	// NOTE(review): 1e9 is an unusual value for log flags — confirm intent.
+	return &Log{Logger: log.New(out, "[SESSION]", 1e9)}
+}
diff --git a/pkg/session/ssdb/sess_ssdb.go b/pkg/session/ssdb/sess_ssdb.go
new file mode 100644
index 00000000..de0c6360
--- /dev/null
+++ b/pkg/session/ssdb/sess_ssdb.go
@@ -0,0 +1,199 @@
+package ssdb
+
+import (
+ "errors"
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+
+ "github.com/astaxie/beego/session"
+ "github.com/ssdb/gossdb/ssdb"
+)
+
+// ssdbProvider is the package-level Provider instance registered with the
+// session package in init().
+var ssdbProvider = &Provider{}
+
+// Provider holds the ssdb client and its connection configuration.
+type Provider struct {
+	client      *ssdb.Client
+	host        string
+	port        int
+	maxLifetime int64 // session TTL in seconds, passed to ssdb "setx"
+}
+
+func (p *Provider) connectInit() error {
+ var err error
+ if p.host == "" || p.port == 0 {
+ return errors.New("SessionInit First")
+ }
+ p.client, err = ssdb.Connect(p.host, p.port)
+ return err
+}
+
+// SessionInit configures the provider from savePath ("host:port") and
+// establishes the initial connection.
+func (p *Provider) SessionInit(maxLifetime int64, savePath string) error {
+	p.maxLifetime = maxLifetime
+	address := strings.Split(savePath, ":")
+	if len(address) < 2 {
+		// Previously address[1] would panic on a savePath without ":".
+		return errors.New("ssdb save path must be in host:port form")
+	}
+	p.host = address[0]
+
+	var err error
+	if p.port, err = strconv.Atoi(address[1]); err != nil {
+		return err
+	}
+	return p.connectInit()
+}
+
+// SessionRead return a ssdb client session Store
+func (p *Provider) SessionRead(sid string) (session.Store, error) {
+ if p.client == nil {
+ if err := p.connectInit(); err != nil {
+ return nil, err
+ }
+ }
+ var kv map[interface{}]interface{}
+ value, err := p.client.Get(sid)
+ if err != nil {
+ return nil, err
+ }
+ if value == nil || len(value.(string)) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = session.DecodeGob([]byte(value.(string)))
+ if err != nil {
+ return nil, err
+ }
+ }
+ rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
+ return rs, nil
+}
+
+// SessionExist reports whether a session with id sid is stored in ssdb.
+// It panics when the backend cannot be reached or queried.
+func (p *Provider) SessionExist(sid string) bool {
+	if p.client == nil {
+		if err := p.connectInit(); err != nil {
+			panic(err)
+		}
+	}
+	value, err := p.client.Get(sid)
+	if err != nil {
+		panic(err)
+	}
+	return value != nil && len(value.(string)) != 0
+}
+
+// SessionRegenerate moves the session stored under oldsid to sid and
+// returns a Store bound to the new id. A missing old session yields an
+// empty one.
+func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+	//conn.Do("setx", key, v, ttl)
+	if p.client == nil {
+		if err := p.connectInit(); err != nil {
+			return nil, err
+		}
+	}
+	value, err := p.client.Get(oldsid)
+	if err != nil {
+		return nil, err
+	}
+	var kv map[interface{}]interface{}
+	if value == nil || len(value.(string)) == 0 {
+		kv = make(map[interface{}]interface{})
+	} else {
+		kv, err = session.DecodeGob([]byte(value.(string)))
+		if err != nil {
+			return nil, err
+		}
+		// Remove the old key only after a successful decode.
+		_, err = p.client.Del(oldsid)
+		if err != nil {
+			return nil, err
+		}
+	}
+	// NOTE(review): when the old session was empty, value is nil here and
+	// is still passed to "setx" — confirm the intended stored payload.
+	_, e := p.client.Do("setx", sid, value, p.maxLifetime)
+	if e != nil {
+		return nil, e
+	}
+	rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
+	return rs, nil
+}
+
+// SessionDestroy removes the session stored under sid from ssdb.
+func (p *Provider) SessionDestroy(sid string) error {
+	if p.client == nil {
+		if err := p.connectInit(); err != nil {
+			return err
+		}
+	}
+	_, err := p.client.Del(sid)
+	return err
+}
+
+// SessionGC is a no-op: ssdb expires keys itself via the "setx" TTL.
+func (p *Provider) SessionGC() {
+}
+
+// SessionAll is not implemented for ssdb; it always reports 0 sessions.
+func (p *Provider) SessionAll() int {
+	return 0
+}
+
+// SessionStore holds one session's data. Values live in memory and are
+// only written back to ssdb on SessionRelease.
+type SessionStore struct {
+	sid         string
+	lock        sync.RWMutex // guards values
+	values      map[interface{}]interface{}
+	maxLifetime int64
+	client      *ssdb.Client
+}
+
+// Set stores value under key in the in-memory session map.
+func (s *SessionStore) Set(key, value interface{}) error {
+	s.lock.Lock()
+	s.values[key] = value
+	s.lock.Unlock()
+	return nil
+}
+
+// Get returns the value stored under key, or nil when absent.
+func (s *SessionStore) Get(key interface{}) interface{} {
+	// Read-only access: take the read side of the RWMutex. The original
+	// took the exclusive lock, needlessly serializing concurrent readers.
+	s.lock.RLock()
+	defer s.lock.RUnlock()
+	if value, ok := s.values[key]; ok {
+		return value
+	}
+	return nil
+}
+
+// Delete removes key (and its value) from the session map.
+func (s *SessionStore) Delete(key interface{}) error {
+	s.lock.Lock()
+	delete(s.values, key)
+	s.lock.Unlock()
+	return nil
+}
+
+// Flush drops every key/value pair held by this session.
+func (s *SessionStore) Flush() error {
+	s.lock.Lock()
+	s.values = map[interface{}]interface{}{}
+	s.lock.Unlock()
+	return nil
+}
+
+// SessionID returns the id of this session.
+func (s *SessionStore) SessionID() string {
+	return s.sid
+}
+
+// SessionRelease flushes the session values into ssdb with the configured
+// TTL (via "setx"). Encoding errors are silently dropped, matching the
+// Store interface which offers no way to report them.
+func (s *SessionStore) SessionRelease(w http.ResponseWriter) {
+	// Consistency fix: every sibling method locks around s.values, but the
+	// original encoded the map here without any lock.
+	s.lock.RLock()
+	b, err := session.EncodeGob(s.values)
+	s.lock.RUnlock()
+	if err != nil {
+		return
+	}
+	s.client.Do("setx", s.sid, string(b), s.maxLifetime)
+}
+
+// init registers this provider under the "ssdb" name so it can be selected
+// through the session manager configuration.
+func init() {
+	session.Register("ssdb", ssdbProvider)
+}
diff --git a/pkg/staticfile.go b/pkg/staticfile.go
new file mode 100644
index 00000000..84e9aa7b
--- /dev/null
+++ b/pkg/staticfile.go
@@ -0,0 +1,234 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "bytes"
+ "errors"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/logs"
+ "github.com/hashicorp/golang-lru"
+)
+
+// errNotStaticRequest is the sentinel returned by searchFile/lookupFile
+// when the request path matches no configured static prefix.
+var errNotStaticRequest = errors.New("request not a static file request")
+
+// serverStaticRouter serves a static file matching the request, when any
+// configured static prefix applies. Non-GET/HEAD requests and non-static
+// paths fall through to the regular router.
+func serverStaticRouter(ctx *context.Context) {
+	if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" {
+		return
+	}
+
+	forbidden, filePath, fileInfo, err := lookupFile(ctx)
+	if err == errNotStaticRequest {
+		// Not a static request at all; let the normal router handle it.
+		return
+	}
+
+	if forbidden {
+		exception("403", ctx)
+		return
+	}
+
+	if filePath == "" || fileInfo == nil {
+		if BConfig.RunMode == DEV {
+			logs.Warn("Can't find/open the file:", filePath, err)
+		}
+		http.NotFound(ctx.ResponseWriter, ctx.Request)
+		return
+	}
+	if fileInfo.IsDir() {
+		requestURL := ctx.Input.URL()
+		if requestURL[len(requestURL)-1] != '/' {
+			// Canonicalize directory URLs with a trailing slash,
+			// preserving the query string.
+			redirectURL := requestURL + "/"
+			if ctx.Request.URL.RawQuery != "" {
+				redirectURL = redirectURL + "?" + ctx.Request.URL.RawQuery
+			}
+			ctx.Redirect(302, redirectURL)
+		} else {
+			//serveFile will list dir
+			http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
+		}
+		return
+	} else if fileInfo.Size() > int64(BConfig.WebConfig.StaticCacheFileSize) {
+		//over size file serve with http module
+		http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
+		return
+	}
+
+	var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath)
+	var acceptEncoding string
+	if enableCompress {
+		acceptEncoding = context.ParseEncoding(ctx.Request)
+	}
+	b, n, sch, reader, err := openFile(filePath, fileInfo, acceptEncoding)
+	if err != nil {
+		if BConfig.RunMode == DEV {
+			logs.Warn("Can't compress the file:", filePath, err)
+		}
+		http.NotFound(ctx.ResponseWriter, ctx.Request)
+		return
+	}
+
+	// b reports whether the cached content is compressed; n is the encoding.
+	if b {
+		ctx.Output.Header("Content-Encoding", n)
+	} else {
+		ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10))
+	}
+
+	http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, reader)
+}
+
+// serveContentHolder is one cached static file: its (possibly compressed)
+// bytes plus the metadata needed for cache validation and HTTP serving.
+type serveContentHolder struct {
+	data       []byte
+	modTime    time.Time
+	size       int64  // size of data (compressed size when encoded)
+	originSize int64  //original file size:to judge file changed
+	encoding   string // "" when data is stored uncompressed
+}
+
+// serveContentReader adapts a cached byte slice to the io.ReadSeeker that
+// http.ServeContent requires.
+type serveContentReader struct {
+	*bytes.Reader
+}
+
+var (
+	// staticFileLruCache caches serveContentHolder values, keyed by
+	// "encoding:path"; it is built lazily in openFile.
+	staticFileLruCache *lru.Cache
+	// lruLock guards creation and refresh of cache entries.
+	lruLock sync.RWMutex
+)
+
+// openFile returns the contents of filePath for serving, preferring the
+// LRU cache when the cached entry still matches the file's mod-time and
+// size. acceptEncoding selects an optional compressed representation.
+// It returns (encoded?, encoding name, cache holder, reader, error).
+func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) {
+	lruLock.Lock()
+	if staticFileLruCache == nil {
+		// Build the cache lazily under the write lock. The original
+		// initialized it unguarded, racing concurrent first requests.
+		if BConfig.WebConfig.StaticCacheFileNum >= 1 {
+			staticFileLruCache, _ = lru.New(BConfig.WebConfig.StaticCacheFileNum)
+		} else {
+			staticFileLruCache, _ = lru.New(1)
+		}
+	}
+	lruLock.Unlock()
+	mapKey := acceptEncoding + ":" + filePath
+	lruLock.RLock()
+	var mapFile *serveContentHolder
+	if cacheItem, ok := staticFileLruCache.Get(mapKey); ok {
+		mapFile = cacheItem.(*serveContentHolder)
+	}
+	lruLock.RUnlock()
+	if isOk(mapFile, fi) {
+		reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)}
+		return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil
+	}
+	// Cache miss or stale entry: re-check under the write lock, then
+	// (re)load and compress the file.
+	lruLock.Lock()
+	defer lruLock.Unlock()
+	if cacheItem, ok := staticFileLruCache.Get(mapKey); ok {
+		mapFile = cacheItem.(*serveContentHolder)
+	}
+	if !isOk(mapFile, fi) {
+		file, err := os.Open(filePath)
+		if err != nil {
+			return false, "", nil, nil, err
+		}
+		defer file.Close()
+		var bufferWriter bytes.Buffer
+		_, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file)
+		if err != nil {
+			return false, "", nil, nil, err
+		}
+		mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), originSize: fi.Size(), encoding: n}
+		// Only cache entries within the size limit (isOk re-checks it).
+		if isOk(mapFile, fi) {
+			staticFileLruCache.Add(mapKey, mapFile)
+		}
+	}
+
+	reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)}
+	return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil
+}
+
+// isOk reports whether cached holder s is still valid for file fi: it must
+// exist, fit the cacheable size limit, and match fi's mod-time and size.
+func isOk(s *serveContentHolder, fi os.FileInfo) bool {
+	if s == nil || s.size > int64(BConfig.WebConfig.StaticCacheFileSize) {
+		return false
+	}
+	return s.modTime == fi.ModTime() && s.originSize == fi.Size()
+}
+
+// isStaticCompress reports whether filePath's extension is configured for
+// compressed serving (StaticExtensionsToGzip), case-insensitively.
+func isStaticCompress(filePath string) bool {
+	// Hoisted out of the loop: the original lowercased filePath once per
+	// configured extension.
+	lower := strings.ToLower(filePath)
+	for _, statExtension := range BConfig.WebConfig.StaticExtensionsToGzip {
+		if strings.HasSuffix(lower, strings.ToLower(statExtension)) {
+			return true
+		}
+	}
+	return false
+}
+
+// searchFile maps the request URL onto a file below one of the configured
+// static directories. favicon.ico and robots.txt are additionally looked
+// up in the working directory and in every static dir. It returns
+// errNotStaticRequest when no static file prefix matches.
+func searchFile(ctx *context.Context) (string, os.FileInfo, error) {
+	requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path))
+	// special processing : favicon.ico/robots.txt can be in any static dir
+	if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
+		file := path.Join(".", requestPath)
+		if fi, _ := os.Stat(file); fi != nil {
+			return file, fi, nil
+		}
+		for _, staticDir := range BConfig.WebConfig.StaticDir {
+			filePath := path.Join(staticDir, requestPath)
+			if fi, _ := os.Stat(filePath); fi != nil {
+				return filePath, fi, nil
+			}
+		}
+		return "", nil, errNotStaticRequest
+	}
+
+	for prefix, staticDir := range BConfig.WebConfig.StaticDir {
+		// Fix: the prefix must anchor the path. strings.Contains matched
+		// the prefix anywhere in the URL, after which
+		// requestPath[len(prefix):] sliced from the wrong position.
+		if !strings.HasPrefix(requestPath, prefix) {
+			continue
+		}
+		if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' {
+			continue
+		}
+		filePath := path.Join(staticDir, requestPath[len(prefix):])
+		if fi, err := os.Stat(filePath); fi != nil {
+			return filePath, fi, err
+		}
+	}
+	return "", nil, errNotStaticRequest
+}
+
+// lookupFile finds the file to serve for the request and returns
+// (forbidden, path, fileInfo, error).
+// When the match is a directory and the URL ends in '/', index.html inside
+// it is preferred if it exists and is a regular file; otherwise the
+// directory itself is returned, forbidden unless DirectoryIndex is enabled.
+func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) {
+	fp, fi, err := searchFile(ctx)
+	if fp == "" || fi == nil {
+		return false, "", nil, err
+	}
+	if !fi.IsDir() {
+		return false, fp, fi, err
+	}
+	if requestURL := ctx.Input.URL(); requestURL[len(requestURL)-1] == '/' {
+		ifp := filepath.Join(fp, "index.html")
+		if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() {
+			return false, ifp, ifi, err
+		}
+	}
+	return !BConfig.WebConfig.DirectoryIndex, fp, fi, err
+}
diff --git a/pkg/staticfile_test.go b/pkg/staticfile_test.go
new file mode 100644
index 00000000..e46c13ec
--- /dev/null
+++ b/pkg/staticfile_test.go
@@ -0,0 +1,99 @@
+package beego
+
+import (
+ "bytes"
+ "compress/gzip"
+ "compress/zlib"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+// currentWorkDir is the test working directory; licenseFile points at the
+// repository LICENSE file used as the static-file fixture in these tests.
+var currentWorkDir, _ = os.Getwd()
+var licenseFile = filepath.Join(currentWorkDir, "LICENSE")
+
+// testOpenFile runs openFile on the LICENSE fixture with the given
+// encoding and checks the result against the expected content.
+func testOpenFile(encoding string, content []byte, t *testing.T) {
+	info, _ := os.Stat(licenseFile)
+	compressed, name, holder, reader, err := openFile(licenseFile, info, encoding)
+	if err != nil {
+		t.Log(err)
+		t.Fail()
+	}
+
+	t.Log("open static file encoding "+name, compressed)
+
+	assetOpenFileAndContent(holder, reader, content, t)
+}
+// TestOpenStaticFile_1 serves the fixture uncompressed and expects the raw
+// file bytes back.
+func TestOpenStaticFile_1(t *testing.T) {
+	raw, _ := ioutil.ReadFile(licenseFile)
+	testOpenFile("", raw, t)
+}
+
+// TestOpenStaticFileGzip_1 serves the fixture gzip-compressed and expects
+// bytes equal to a best-compression gzip of the file.
+func TestOpenStaticFileGzip_1(t *testing.T) {
+	src, _ := os.Open(licenseFile)
+	var compressed bytes.Buffer
+	zw, _ := gzip.NewWriterLevel(&compressed, gzip.BestCompression)
+	io.Copy(zw, src)
+	zw.Close()
+
+	testOpenFile("gzip", compressed.Bytes(), t)
+}
+// TestOpenStaticFileDeflate_1 serves the fixture deflate-compressed and
+// expects bytes equal to a best-compression zlib stream of the file.
+func TestOpenStaticFileDeflate_1(t *testing.T) {
+	src, _ := os.Open(licenseFile)
+	var compressed bytes.Buffer
+	zw, _ := zlib.NewWriterLevel(&compressed, zlib.BestCompression)
+	io.Copy(zw, src)
+	zw.Close()
+
+	testOpenFile("deflate", compressed.Bytes(), t)
+}
+
+// TestStaticCacheWork verifies that a second openFile call for the same
+// path+encoding returns the identical cached holder for every encoding.
+func TestStaticCacheWork(t *testing.T) {
+	fi, _ := os.Stat(licenseFile)
+	for _, encoding := range []string{"", "gzip", "deflate"} {
+		_, _, first, _, err := openFile(licenseFile, fi, encoding)
+		if err != nil {
+			t.Error(err)
+			continue
+		}
+
+		_, _, second, _, err := openFile(licenseFile, fi, encoding)
+		if err != nil {
+			t.Error(err)
+			continue
+		}
+
+		// Identical pointer representation proves the cache was hit.
+		addrFirst := fmt.Sprintf("%p", first)
+		addrSecond := fmt.Sprintf("%p", second)
+		if addrFirst != addrSecond {
+			t.Errorf("encoding '%v' can not hit cache", encoding)
+		}
+	}
+}
+
+// assetOpenFileAndContent checks that the holder/reader pair produced by
+// openFile matches the expected content and that the LRU cache was filled.
+func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) {
+	t.Log(sch.size, len(content))
+	if sch.size != int64(len(content)) {
+		t.Log("static content file size not same")
+		t.Fail()
+	}
+	bs, _ := ioutil.ReadAll(reader)
+	// Compare whole slices: the original indexed bs[i] for every i in
+	// content and would panic (not fail) when bs was shorter.
+	if !bytes.Equal(bs, content) {
+		t.Log("content not same")
+		t.Fail()
+	}
+	if staticFileLruCache.Len() == 0 {
+		t.Log("lru cache is empty")
+		t.Fail()
+	}
+}
diff --git a/pkg/swagger/swagger.go b/pkg/swagger/swagger.go
new file mode 100644
index 00000000..a55676cd
--- /dev/null
+++ b/pkg/swagger/swagger.go
@@ -0,0 +1,174 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Swagger™ is a project used to describe and document RESTful APIs.
+//
+// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools.
+// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software.
+
+// Package swagger struct definition
+package swagger
+
+// Swagger is the root document object of a Swagger 2.0 specification,
+// listing the API's resources, definitions and security schemes.
+type Swagger struct {
+	SwaggerVersion      string                `json:"swagger,omitempty" yaml:"swagger,omitempty"`
+	Infos               Information           `json:"info" yaml:"info"`
+	Host                string                `json:"host,omitempty" yaml:"host,omitempty"`
+	BasePath            string                `json:"basePath,omitempty" yaml:"basePath,omitempty"`
+	Schemes             []string              `json:"schemes,omitempty" yaml:"schemes,omitempty"`
+	Consumes            []string              `json:"consumes,omitempty" yaml:"consumes,omitempty"`
+	Produces            []string              `json:"produces,omitempty" yaml:"produces,omitempty"`
+	Paths               map[string]*Item      `json:"paths" yaml:"paths"`
+	Definitions         map[string]Schema     `json:"definitions,omitempty" yaml:"definitions,omitempty"`
+	SecurityDefinitions map[string]Security   `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"`
+	Security            []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"`
+	Tags                []Tag                 `json:"tags,omitempty" yaml:"tags,omitempty"`
+	ExternalDocs        *ExternalDocs         `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"`
+}
+
+// Information Provides metadata about the API. The metadata can be used by the clients if needed.
+type Information struct {
+	Title          string `json:"title,omitempty" yaml:"title,omitempty"`
+	Description    string `json:"description,omitempty" yaml:"description,omitempty"`
+	Version        string `json:"version,omitempty" yaml:"version,omitempty"`
+	TermsOfService string `json:"termsOfService,omitempty" yaml:"termsOfService,omitempty"`
+
+	Contact Contact  `json:"contact,omitempty" yaml:"contact,omitempty"`
+	License *License `json:"license,omitempty" yaml:"license,omitempty"`
+}
+
+// Contact information for the exposed API.
+type Contact struct {
+	Name  string `json:"name,omitempty" yaml:"name,omitempty"`
+	URL   string `json:"url,omitempty" yaml:"url,omitempty"`
+	EMail string `json:"email,omitempty" yaml:"email,omitempty"`
+}
+
+// License information for the exposed API.
+type License struct {
+	Name string `json:"name,omitempty" yaml:"name,omitempty"`
+	URL  string `json:"url,omitempty" yaml:"url,omitempty"`
+}
+
+// Item Describes the operations available on a single path.
+type Item struct {
+	Ref     string     `json:"$ref,omitempty" yaml:"$ref,omitempty"`
+	Get     *Operation `json:"get,omitempty" yaml:"get,omitempty"`
+	Put     *Operation `json:"put,omitempty" yaml:"put,omitempty"`
+	Post    *Operation `json:"post,omitempty" yaml:"post,omitempty"`
+	Delete  *Operation `json:"delete,omitempty" yaml:"delete,omitempty"`
+	Options *Operation `json:"options,omitempty" yaml:"options,omitempty"`
+	Head    *Operation `json:"head,omitempty" yaml:"head,omitempty"`
+	Patch   *Operation `json:"patch,omitempty" yaml:"patch,omitempty"`
+}
+
+// Operation Describes a single API operation on a path.
+type Operation struct {
+	Tags        []string              `json:"tags,omitempty" yaml:"tags,omitempty"`
+	Summary     string                `json:"summary,omitempty" yaml:"summary,omitempty"`
+	Description string                `json:"description,omitempty" yaml:"description,omitempty"`
+	OperationID string                `json:"operationId,omitempty" yaml:"operationId,omitempty"`
+	Consumes    []string              `json:"consumes,omitempty" yaml:"consumes,omitempty"`
+	Produces    []string              `json:"produces,omitempty" yaml:"produces,omitempty"`
+	Schemes     []string              `json:"schemes,omitempty" yaml:"schemes,omitempty"`
+	Parameters  []Parameter           `json:"parameters,omitempty" yaml:"parameters,omitempty"`
+	Responses   map[string]Response   `json:"responses,omitempty" yaml:"responses,omitempty"`
+	Security    []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"`
+	Deprecated  bool                  `json:"deprecated,omitempty" yaml:"deprecated,omitempty"`
+}
+
+// Parameter Describes a single operation parameter.
+type Parameter struct {
+	In          string          `json:"in,omitempty" yaml:"in,omitempty"`
+	Name        string          `json:"name,omitempty" yaml:"name,omitempty"`
+	Description string          `json:"description,omitempty" yaml:"description,omitempty"`
+	Required    bool            `json:"required,omitempty" yaml:"required,omitempty"`
+	Schema      *Schema         `json:"schema,omitempty" yaml:"schema,omitempty"`
+	Type        string          `json:"type,omitempty" yaml:"type,omitempty"`
+	Format      string          `json:"format,omitempty" yaml:"format,omitempty"`
+	Items       *ParameterItems `json:"items,omitempty" yaml:"items,omitempty"`
+	Default     interface{}     `json:"default,omitempty" yaml:"default,omitempty"`
+}
+
+// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body".
+// http://swagger.io/specification/#itemsObject
+type ParameterItems struct {
+	Type             string            `json:"type,omitempty" yaml:"type,omitempty"`
+	Format           string            `json:"format,omitempty" yaml:"format,omitempty"`
+	Items            []*ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` //Required if type is "array". Describes the type of items in the array.
+	CollectionFormat string            `json:"collectionFormat,omitempty" yaml:"collectionFormat,omitempty"`
+	Default          string            `json:"default,omitempty" yaml:"default,omitempty"`
+}
+
+// Schema Object allows the definition of input and output data types.
+type Schema struct {
+	Ref         string               `json:"$ref,omitempty" yaml:"$ref,omitempty"`
+	Title       string               `json:"title,omitempty" yaml:"title,omitempty"`
+	Format      string               `json:"format,omitempty" yaml:"format,omitempty"`
+	Description string               `json:"description,omitempty" yaml:"description,omitempty"`
+	Required    []string             `json:"required,omitempty" yaml:"required,omitempty"`
+	Type        string               `json:"type,omitempty" yaml:"type,omitempty"`
+	Items       *Schema              `json:"items,omitempty" yaml:"items,omitempty"`
+	Properties  map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"`
+	Enum        []interface{}        `json:"enum,omitempty" yaml:"enum,omitempty"`
+	Example     interface{}          `json:"example,omitempty" yaml:"example,omitempty"`
+}
+
+// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification.
+// (The name "Propertie" is a historical typo kept for API compatibility.)
+type Propertie struct {
+	Ref                  string               `json:"$ref,omitempty" yaml:"$ref,omitempty"`
+	Title                string               `json:"title,omitempty" yaml:"title,omitempty"`
+	Description          string               `json:"description,omitempty" yaml:"description,omitempty"`
+	Default              interface{}          `json:"default,omitempty" yaml:"default,omitempty"`
+	Type                 string               `json:"type,omitempty" yaml:"type,omitempty"`
+	Example              interface{}          `json:"example,omitempty" yaml:"example,omitempty"`
+	Required             []string             `json:"required,omitempty" yaml:"required,omitempty"`
+	Format               string               `json:"format,omitempty" yaml:"format,omitempty"`
+	ReadOnly             bool                 `json:"readOnly,omitempty" yaml:"readOnly,omitempty"`
+	Properties           map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"`
+	Items                *Propertie           `json:"items,omitempty" yaml:"items,omitempty"`
+	AdditionalProperties *Propertie           `json:"additionalProperties,omitempty" yaml:"additionalProperties,omitempty"`
+}
+
+// Response as they are returned from executing this operation.
+type Response struct {
+	Description string  `json:"description" yaml:"description"`
+	Schema      *Schema `json:"schema,omitempty" yaml:"schema,omitempty"`
+	Ref         string  `json:"$ref,omitempty" yaml:"$ref,omitempty"`
+}
+
+// Security Allows the definition of a security scheme that can be used by the operations
+type Security struct {
+	Type             string            `json:"type,omitempty" yaml:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2".
+	Description      string            `json:"description,omitempty" yaml:"description,omitempty"`
+	Name             string            `json:"name,omitempty" yaml:"name,omitempty"`
+	In               string            `json:"in,omitempty" yaml:"in,omitempty"` // Valid values are "query" or "header".
+	Flow             string            `json:"flow,omitempty" yaml:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode".
+	AuthorizationURL string            `json:"authorizationUrl,omitempty" yaml:"authorizationUrl,omitempty"`
+	TokenURL         string            `json:"tokenUrl,omitempty" yaml:"tokenUrl,omitempty"`
+	Scopes           map[string]string `json:"scopes,omitempty" yaml:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme.
+}
+
+// Tag Allows adding meta data to a single tag that is used by the Operation Object
+type Tag struct {
+	Name         string        `json:"name,omitempty" yaml:"name,omitempty"`
+	Description  string        `json:"description,omitempty" yaml:"description,omitempty"`
+	ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"`
+}
+
+// ExternalDocs include Additional external documentation
+type ExternalDocs struct {
+	Description string `json:"description,omitempty" yaml:"description,omitempty"`
+	URL         string `json:"url,omitempty" yaml:"url,omitempty"`
+}
diff --git a/pkg/template.go b/pkg/template.go
new file mode 100644
index 00000000..59875be7
--- /dev/null
+++ b/pkg/template.go
@@ -0,0 +1,406 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "errors"
+ "fmt"
+ "html/template"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+ "sync"
+
+ "github.com/astaxie/beego/logs"
+ "github.com/astaxie/beego/utils"
+)
+
+var (
+	// beegoTplFuncMap holds the helper functions exposed to every parsed template.
+	beegoTplFuncMap = make(template.FuncMap)
+	// beeViewPathTemplateLocked is set by lockViewPaths (at beego.Run());
+	// afterwards AddViewPath panics for view paths not already registered.
+	beeViewPathTemplateLocked = false
+	// beeViewPathTemplates caching map and supported template file extensions per view
+	beeViewPathTemplates = make(map[string]map[string]*template.Template)
+	// templatesLock guards the template caches (taken for reads only in dev mode).
+	templatesLock sync.RWMutex
+	// beeTemplateExt stores the template extension which will build
+	beeTemplateExt = []string{"tpl", "html", "gohtml"}
+	// beeTemplatePreprocessors stores associations of extension -> preprocessor handler
+	beeTemplateEngines = map[string]templatePreProcessor{}
+	// beeTemplateFS builds the http.FileSystem templates are read from.
+	beeTemplateFS = defaultFSFunc
+)
+
+// ExecuteTemplate applies the template with name to the specified data object,
+// writing the output to wr.
+// A template will be executed safely in parallel.
+// NOTE: it panics (via ExecuteViewPathTemplate) when the configured view path
+// or the named template is unknown.
+func ExecuteTemplate(wr io.Writer, name string, data interface{}) error {
+	return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data)
+}
+
+// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object,
+// writing the output to wr.
+// A template will be executed safely in parallel.
+// It panics when viewPath is not a registered view path or name is not a
+// cached template under it.
+func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error {
+	if BConfig.RunMode == DEV {
+		// Templates may be rebuilt on the fly in dev mode, so guard reads.
+		templatesLock.RLock()
+		defer templatesLock.RUnlock()
+	}
+	beeTemplates, found := beeViewPathTemplates[viewPath]
+	if !found {
+		panic("Unknown view path:" + viewPath)
+	}
+	t, found := beeTemplates[name]
+	if !found {
+		panic("can't find templatefile in the path:" + viewPath + "/" + name)
+	}
+	var execErr error
+	if t.Lookup(name) != nil {
+		execErr = t.ExecuteTemplate(wr, name, data)
+	} else {
+		execErr = t.Execute(wr, data)
+	}
+	if execErr != nil {
+		logs.Trace("template Execute err:", execErr)
+	}
+	return execErr
+}
+
+func init() {
+	// Register beego's built-in template helpers under their historical names.
+	beegoTplFuncMap["dateformat"] = DateFormat
+	beegoTplFuncMap["date"] = Date
+	beegoTplFuncMap["compare"] = Compare
+	beegoTplFuncMap["compare_not"] = CompareNot
+	beegoTplFuncMap["not_nil"] = NotNil
+	beegoTplFuncMap["not_null"] = NotNil
+	beegoTplFuncMap["substr"] = Substr
+	beegoTplFuncMap["html2str"] = HTML2str
+	beegoTplFuncMap["str2html"] = Str2html
+	beegoTplFuncMap["htmlquote"] = Htmlquote
+	beegoTplFuncMap["htmlunquote"] = Htmlunquote
+	beegoTplFuncMap["renderform"] = RenderForm
+	beegoTplFuncMap["assets_js"] = AssetsJs
+	beegoTplFuncMap["assets_css"] = AssetsCSS
+	beegoTplFuncMap["config"] = GetConfig
+	beegoTplFuncMap["map_get"] = MapGet
+
+	// Comparisons
+	beegoTplFuncMap["eq"] = eq // ==
+	beegoTplFuncMap["ge"] = ge // >=
+	beegoTplFuncMap["gt"] = gt // >
+	beegoTplFuncMap["le"] = le // <=
+	beegoTplFuncMap["lt"] = lt // <
+	beegoTplFuncMap["ne"] = ne // !=
+
+	beegoTplFuncMap["urlfor"] = URLFor // build a URL to match a Controller and it's method
+}
+
+// AddFuncMap let user to register a func in the template.
+// Functions registered before templates are built become available to them.
+// It currently never fails; the error return is kept for API compatibility.
+func AddFuncMap(key string, fn interface{}) error {
+	beegoTplFuncMap[key] = fn
+	return nil
+}
+
+// templatePreProcessor renders a template file with a custom engine, given the
+// view root, the root-relative file path and the template func map.
+type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error)
+
+// templateFile accumulates, during a directory walk, the template files found
+// under root, grouped by sub-directory (see visit).
+type templateFile struct {
+	root  string
+	files map[string][]string
+}
+
+// visit records one walked path in tf.files. The path is split into a subDir
+// (relative to tf.root) and a root-relative file path:
+// root "views", path "views/errors/404.html"       -> subDir "errors", file "errors/404.html"
+// root "views", path "views/admin/errors/404.html" -> subDir "admin/errors", file "admin/errors/404.html"
+func (tf *templateFile) visit(paths string, f os.FileInfo, err error) error {
+	switch {
+	case f == nil:
+		return err
+	case f.IsDir() || (f.Mode()&os.ModeSymlink) > 0:
+		// Only regular files can be templates.
+		return nil
+	case !HasTemplateExt(paths):
+		return nil
+	}
+
+	// Normalize Windows separators, then strip the root prefix.
+	rel := strings.Replace(paths[len(tf.root):], "\\", "/", -1)
+	file := strings.TrimLeft(rel, "/")
+	subDir := filepath.Dir(file)
+	tf.files[subDir] = append(tf.files[subDir], file)
+	return nil
+}
+
+// HasTemplateExt return this path contains supported template extension of beego or not.
+// It checks paths against every registered extension (see AddTemplateExt).
+func HasTemplateExt(paths string) bool {
+	for i := range beeTemplateExt {
+		if strings.HasSuffix(paths, "."+beeTemplateExt[i]) {
+			return true
+		}
+	}
+	return false
+}
+
+// AddTemplateExt add new extension for template.
+// Registering an extension that is already known is a no-op.
+func AddTemplateExt(ext string) {
+	for _, registered := range beeTemplateExt {
+		if registered == ext {
+			return
+		}
+	}
+	beeTemplateExt = append(beeTemplateExt, ext)
+}
+
+// AddViewPath adds a new path to the supported view paths.
+//Can later be used by setting a controller ViewPath to this folder
+//will panic if called after beego.Run()
+func AddViewPath(viewPath string) error {
+	if beeViewPathTemplateLocked {
+		if _, exist := beeViewPathTemplates[viewPath]; exist {
+			return nil //Ignore if viewpath already exists
+		}
+		panic("Can not add new view paths after beego.Run()")
+	}
+	beeViewPathTemplates[viewPath] = make(map[string]*template.Template)
+	// Eagerly compile every template found under the new path.
+	return BuildTemplate(viewPath)
+}
+
+// lockViewPaths freezes the set of registered view paths; after this,
+// AddViewPath panics for paths that are not already registered.
+func lockViewPaths() {
+	beeViewPathTemplateLocked = true
+}
+
+// BuildTemplate will build all template files in a directory.
+// it makes beego can render any template file in view directory.
+// When files is non-empty, only those root-relative files are (re)built;
+// otherwise every template under dir is parsed. dir must already be a
+// registered view path (see AddViewPath), otherwise this panics.
+func BuildTemplate(dir string, files ...string) error {
+	var err error
+	fs := beeTemplateFS()
+	f, err := fs.Open(dir)
+	if err != nil {
+		if os.IsNotExist(err) {
+			// A missing view directory is not an error: nothing to build.
+			return nil
+		}
+		// Keep the underlying cause instead of a bare "dir open err".
+		return errors.New("dir open err: " + err.Error())
+	}
+	defer f.Close()
+
+	beeTemplates, ok := beeViewPathTemplates[dir]
+	if !ok {
+		panic("Unknown view path: " + dir)
+	}
+	self := &templateFile{
+		root:  dir,
+		files: make(map[string][]string),
+	}
+	// Collect every template file under dir, grouped by sub-directory.
+	err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error {
+		return self.visit(path, f, err)
+	})
+	if err != nil {
+		fmt.Printf("Walk() returned %v\n", err)
+		return err
+	}
+	buildAllFiles := len(files) == 0
+	for _, v := range self.files {
+		for _, file := range v {
+			if buildAllFiles || utils.InSlice(file, files) {
+				templatesLock.Lock()
+				ext := filepath.Ext(file)
+				var t *template.Template
+				if len(ext) == 0 {
+					t, err = getTemplate(self.root, fs, file, v...)
+				} else if fn, ok := beeTemplateEngines[ext[1:]]; ok {
+					// A custom engine is registered for this extension.
+					t, err = fn(self.root, file, beegoTplFuncMap)
+				} else {
+					t, err = getTemplate(self.root, fs, file, v...)
+				}
+				if err != nil {
+					logs.Error("parse template err:", file, err)
+					templatesLock.Unlock()
+					return err
+				}
+				beeTemplates[file] = t
+				templatesLock.Unlock()
+			}
+		}
+	}
+	return nil
+}
+
+// getTplDeep parses file (resolving "../" references relative to parent) into
+// t, then recursively parses every {{template "x.ext"}} file it references.
+// It returns the augmented template, the raw regexp matches of the template
+// actions found, and any parse error. It panics when file cannot be opened.
+func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *template.Template) (*template.Template, [][]string, error) {
+	var fileAbsPath string
+	var rParent string
+	var err error
+	if strings.HasPrefix(file, "../") {
+		// Relative reference: resolve against the referencing file's directory.
+		rParent = filepath.Join(filepath.Dir(parent), file)
+		fileAbsPath = filepath.Join(root, filepath.Dir(parent), file)
+	} else {
+		rParent = file
+		fileAbsPath = filepath.Join(root, file)
+	}
+	f, err := fs.Open(fileAbsPath)
+	if err != nil {
+		panic("can't find template file:" + file)
+	}
+	defer f.Close()
+	data, err := ioutil.ReadAll(f)
+	if err != nil {
+		return nil, [][]string{}, err
+	}
+	t, err = t.New(file).Parse(string(data))
+	if err != nil {
+		return nil, [][]string{}, err
+	}
+	// Find {{template "name"}} actions, honoring the configured left delimiter.
+	reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"")
+	allSub := reg.FindAllStringSubmatch(string(data), -1)
+	for _, m := range allSub {
+		if len(m) == 2 {
+			tl := t.Lookup(m[1])
+			if tl != nil {
+				// Already defined by an earlier parse; skip.
+				continue
+			}
+			if !HasTemplateExt(m[1]) {
+				// Names without a template extension are resolved later from
+				// {{define}} blocks (see _getTemplate), not from files.
+				continue
+			}
+			_, _, err = getTplDeep(root, fs, m[1], rParent, t)
+			if err != nil {
+				return nil, [][]string{}, err
+			}
+		}
+	}
+	return t, allSub, nil
+}
+
+// getTemplate parses file (plus any templates it references, directly or via
+// {{define}} blocks in the sibling files listed in others) from root on fs,
+// configured with beego's delimiters and func map.
+func getTemplate(root string, fs http.FileSystem, file string, others ...string) (t *template.Template, err error) {
+	t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap)
+	var subMods [][]string
+	if t, subMods, err = getTplDeep(root, fs, file, "", t); err != nil {
+		return nil, err
+	}
+	if t, err = _getTemplate(t, root, fs, subMods, others...); err != nil {
+		return nil, err
+	}
+	return t, nil
+}
+
+// _getTemplate resolves the sub-templates referenced by t0. subMods holds the
+// regexp matches of {{template "name"}} actions; for each name not yet defined
+// in t, the candidate files in others are searched twice: first by file name,
+// then by {{define "name"}} blocks inside them. Parse failures of candidates
+// are logged and skipped so one broken sibling does not abort the build.
+func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMods [][]string, others ...string) (t *template.Template, err error) {
+	t = t0
+	for _, m := range subMods {
+		if len(m) == 2 {
+			tpl := t.Lookup(m[1])
+			if tpl != nil {
+				continue
+			}
+			//first check filename
+			for _, otherFile := range others {
+				if otherFile == m[1] {
+					var subMods1 [][]string
+					t, subMods1, err = getTplDeep(root, fs, otherFile, "", t)
+					if err != nil {
+						logs.Trace("template parse file err:", err)
+					} else if len(subMods1) > 0 {
+						t, err = _getTemplate(t, root, fs, subMods1, others...)
+					}
+					break
+				}
+			}
+			//second check define
+			for _, otherFile := range others {
+				var data []byte
+				fileAbsPath := filepath.Join(root, otherFile)
+				f, err := fs.Open(fileAbsPath)
+				if err != nil {
+					// BUG FIX: fs.Open returns a nil file on failure, so the
+					// previous f.Close() here panicked with a nil dereference.
+					logs.Trace("template file parse error, not success open file:", err)
+					continue
+				}
+				data, err = ioutil.ReadAll(f)
+				f.Close()
+				if err != nil {
+					logs.Trace("template file parse error, not success read file:", err)
+					continue
+				}
+				// Look for a {{define "name"}} block matching the missing name.
+				reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"")
+				allSub := reg.FindAllStringSubmatch(string(data), -1)
+				for _, sub := range allSub {
+					if len(sub) == 2 && sub[1] == m[1] {
+						var subMods1 [][]string
+						t, subMods1, err = getTplDeep(root, fs, otherFile, "", t)
+						if err != nil {
+							logs.Trace("template parse file err:", err)
+						} else if len(subMods1) > 0 {
+							t, err = _getTemplate(t, root, fs, subMods1, others...)
+							if err != nil {
+								logs.Trace("template parse file err:", err)
+							}
+						}
+						break
+					}
+				}
+			}
+		}
+
+	}
+	return
+}
+
+// templateFSFunc returns the http.FileSystem templates are loaded from,
+// letting callers plug in e.g. embedded assets instead of the OS disk.
+type templateFSFunc func() http.FileSystem
+
+// defaultFSFunc serves templates from beego's FileSystem type.
+func defaultFSFunc() http.FileSystem {
+	return FileSystem{}
+}
+
+// SetTemplateFSFunc set default filesystem function
+// It should be called before templates are built (i.e. before beego.Run()).
+func SetTemplateFSFunc(fnt templateFSFunc) {
+	beeTemplateFS = fnt
+}
+
+// SetViewsPath sets view directory path in beego application.
+// It returns BeeApp so configuration calls can be chained.
+func SetViewsPath(path string) *App {
+	BConfig.WebConfig.ViewsPath = path
+	return BeeApp
+}
+
+// normalizeStaticURL ensures a static URL pattern starts with "/" and, except
+// for the root pattern, carries no trailing slash.
+func normalizeStaticURL(url string) string {
+	if !strings.HasPrefix(url, "/") {
+		url = "/" + url
+	}
+	if url != "/" {
+		url = strings.TrimRight(url, "/")
+	}
+	return url
+}
+
+// SetStaticPath sets static directory path and proper url pattern in beego application.
+// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public".
+func SetStaticPath(url string, path string) *App {
+	BConfig.WebConfig.StaticDir[normalizeStaticURL(url)] = path
+	return BeeApp
+}
+
+// DelStaticPath removes the static folder setting in this url pattern in beego application.
+// The pattern is normalized the same way SetStaticPath normalizes it.
+func DelStaticPath(url string) *App {
+	pattern := url
+	if !strings.HasPrefix(pattern, "/") {
+		pattern = "/" + pattern
+	}
+	if pattern != "/" {
+		pattern = strings.TrimRight(pattern, "/")
+	}
+	delete(BConfig.WebConfig.StaticDir, pattern)
+	return BeeApp
+}
+
+// AddTemplateEngine add a new templatePreProcessor which support extension
+// The extension is also registered as a recognized template extension.
+func AddTemplateEngine(extension string, fn templatePreProcessor) *App {
+	AddTemplateExt(extension)
+	beeTemplateEngines[extension] = fn
+	return BeeApp
+}
diff --git a/pkg/template_test.go b/pkg/template_test.go
new file mode 100644
index 00000000..287faadc
--- /dev/null
+++ b/pkg/template_test.go
@@ -0,0 +1,316 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package beego
+
+import (
+ "bytes"
+ "github.com/astaxie/beego/testdata"
+ "github.com/elazarl/go-bindata-assetfs"
+ "net/http"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+var header = `{{define "header"}}
+