1
0
mirror of https://github.com/astaxie/beego.git synced 2025-06-12 07:50:39 +00:00

remove pkg directory;

remove build directory;
remove githook directory;
This commit is contained in:
Ming Deng
2020-10-08 17:17:15 +08:00
parent 034cb3222e
commit 14c1b76569
431 changed files with 372 additions and 545 deletions

13
server/web/LICENSE Normal file
View File

@ -0,0 +1,13 @@
Copyright 2014 astaxie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

126
server/web/admin.go Normal file
View File

@ -0,0 +1,126 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"fmt"
"net/http"
"reflect"
"time"
"github.com/astaxie/beego/core/logs"
)
// BeeAdminApp is the default adminApp used by admin module.
var beeAdminApp *adminApp
// FilterMonitorFunc is default monitor filter when admin module is enable.
// if this func returns, admin module records qps for this request by condition of this function logic.
// usage:
// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool {
// if method == "POST" {
// return false
// }
// if t.Nanoseconds() < 100 {
// return false
// }
// if strings.HasPrefix(requestPath, "/astaxie") {
// return false
// }
// return true
// }
// beego.FilterMonitorFunc = MyFilterMonitor.
var FilterMonitorFunc func(string, string, time.Duration, string, int) bool
func init() {
FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true }
}
func list(root string, p interface{}, m M) {
pt := reflect.TypeOf(p)
pv := reflect.ValueOf(p)
if pt.Kind() == reflect.Ptr {
pt = pt.Elem()
pv = pv.Elem()
}
for i := 0; i < pv.NumField(); i++ {
var key string
if root == "" {
key = pt.Field(i).Name
} else {
key = root + "." + pt.Field(i).Name
}
if pv.Field(i).Kind() == reflect.Struct {
list(key, pv.Field(i).Interface(), m)
} else {
m[key] = pv.Field(i).Interface()
}
}
}
func writeJSON(rw http.ResponseWriter, jsonData []byte) {
rw.Header().Set("Content-Type", "application/json")
rw.Write(jsonData)
}
// adminApp is an http.HandlerFunc map used as beeAdminApp.
type adminApp struct {
*HttpServer
}
// Route adds http.HandlerFunc to adminApp with url pattern.
func (admin *adminApp) Run() {
// if len(task.AdminTaskList) > 0 {
// task.StartTask()
// }
logs.Warning("now we don't start tasks here, if you use task module," +
" please invoke task.StartTask, or task will not be executed")
addr := BConfig.Listen.AdminAddr
if BConfig.Listen.AdminPort != 0 {
addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort)
}
logs.Info("Admin server Running on %s", addr)
admin.HttpServer.Run(addr)
}
func registerAdmin() error {
if BConfig.Listen.EnableAdmin {
c := &adminController{
servers: make([]*HttpServer, 0, 2),
}
beeAdminApp = &adminApp{
HttpServer: NewHttpServerWithCfg(*BConfig),
}
// keep in mind that all data should be html escaped to avoid XSS attack
beeAdminApp.Router("/", c, "get:AdminIndex")
beeAdminApp.Router("/qps", c, "get:QpsIndex")
beeAdminApp.Router("/prof", c, "get:ProfIndex")
beeAdminApp.Router("/healthcheck", c, "get:Healthcheck")
beeAdminApp.Router("/task", c, "get:TaskStatus")
beeAdminApp.Router("/listconf", c, "get:ListConf")
beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics")
go beeAdminApp.Run()
}
return nil
}

View File

@ -0,0 +1,297 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"strconv"
"text/template"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/astaxie/beego/core/governor"
)
type adminController struct {
Controller
servers []*HttpServer
}
func (a *adminController) registerHttpServer(svr *HttpServer) {
a.servers = append(a.servers, svr)
}
// ProfIndex is a http.Handler for showing profile command.
// it's in url pattern "/prof" in admin module.
func (a *adminController) ProfIndex() {
rw, r := a.Ctx.ResponseWriter, a.Ctx.Request
r.ParseForm()
command := r.Form.Get("command")
if command == "" {
return
}
var (
format = r.Form.Get("format")
data = make(map[interface{}]interface{})
result bytes.Buffer
)
governor.ProcessInput(command, &result)
data["Content"] = template.HTMLEscapeString(result.String())
if format == "json" && command == "gc summary" {
dataJSON, err := json.Marshal(data)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(rw, dataJSON)
return
}
data["Title"] = template.HTMLEscapeString(command)
defaultTpl := defaultScriptsTpl
if command == "gc summary" {
defaultTpl = gcAjaxTpl
}
writeTemplate(rw, data, profillingTpl, defaultTpl)
}
func (a *adminController) PrometheusMetrics() {
promhttp.Handler().ServeHTTP(a.Ctx.ResponseWriter, a.Ctx.Request)
}
// TaskStatus is a http.Handler with running task status (task name, status and the last execution).
// it's in "/task" pattern in admin module.
func (a *adminController) TaskStatus() {
rw, req := a.Ctx.ResponseWriter, a.Ctx.Request
data := make(map[interface{}]interface{})
// Run Task
req.ParseForm()
taskname := req.Form.Get("taskname")
if taskname != "" {
cmd := governor.GetCommand("task", "run")
res := cmd.Execute(taskname)
if res.IsSuccess() {
data["Message"] = []string{"success",
template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is <br>%s",
taskname, res.Content.(string)))}
} else {
data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", res.Error))}
}
}
// List Tasks
content := make(M)
resultList := governor.GetCommand("task", "list").Execute().Content.([][]string)
var fields = []string{
"Task Name",
"Task Spec",
"Task Status",
"Last Time",
"",
}
content["Fields"] = fields
content["Data"] = resultList
data["Content"] = content
data["Title"] = "Tasks"
writeTemplate(rw, data, tasksTpl, defaultScriptsTpl)
}
func (a *adminController) AdminIndex() {
// AdminIndex is the default http.Handler for admin module.
// it matches url pattern "/".
writeTemplate(a.Ctx.ResponseWriter, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl)
}
// Healthcheck is a http.Handler calling health checking and showing the result.
// it's in "/healthcheck" pattern in admin module.
func (a *adminController) Healthcheck() {
heathCheck(a.Ctx.ResponseWriter, a.Ctx.Request)
}
func heathCheck(rw http.ResponseWriter, r *http.Request) {
var (
result []string
data = make(map[interface{}]interface{})
resultList = new([][]string)
content = M{
"Fields": []string{"Name", "Message", "Status"},
}
)
for name, h := range governor.AdminCheckList {
if err := h.Check(); err != nil {
result = []string{
"error",
template.HTMLEscapeString(name),
template.HTMLEscapeString(err.Error()),
}
} else {
result = []string{
"success",
template.HTMLEscapeString(name),
"OK",
}
}
*resultList = append(*resultList, result)
}
queryParams := r.URL.Query()
jsonFlag := queryParams.Get("json")
shouldReturnJSON, _ := strconv.ParseBool(jsonFlag)
if shouldReturnJSON {
response := buildHealthCheckResponseList(resultList)
jsonResponse, err := json.Marshal(response)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
} else {
writeJSON(rw, jsonResponse)
}
return
}
content["Data"] = resultList
data["Content"] = content
data["Title"] = "Health Check"
writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl)
}
// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter.
// it's registered with url pattern "/qps" in admin module.
func (a *adminController) QpsIndex() {
data := make(map[interface{}]interface{})
data["Content"] = StatisticsMap.GetMap()
// do html escape before display path, avoid xss
if content, ok := (data["Content"]).(M); ok {
if resultLists, ok := (content["Data"]).([][]string); ok {
for i := range resultLists {
if len(resultLists[i]) > 0 {
resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0])
}
}
}
}
writeTemplate(a.Ctx.ResponseWriter, data, qpsTpl, defaultScriptsTpl)
}
// ListConf is the http.Handler of displaying all beego configuration values as key/value pair.
// it's registered with url pattern "/listconf" in admin module.
func (a *adminController) ListConf() {
rw := a.Ctx.ResponseWriter
r := a.Ctx.Request
r.ParseForm()
command := r.Form.Get("command")
if command == "" {
rw.Write([]byte("command not support"))
return
}
data := make(map[interface{}]interface{})
switch command {
case "conf":
m := make(M)
list("BConfig", BConfig, m)
m["appConfigPath"] = template.HTMLEscapeString(appConfigPath)
m["appConfigProvider"] = template.HTMLEscapeString(appConfigProvider)
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
tmpl = template.Must(tmpl.Parse(configTpl))
tmpl = template.Must(tmpl.Parse(defaultScriptsTpl))
data["Content"] = m
tmpl.Execute(rw, data)
case "router":
content := BeeApp.PrintTree()
content["Fields"] = []string{
"Router Pattern",
"Methods",
"Controller",
}
data["Content"] = content
data["Title"] = "Routers"
writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl)
case "filter":
var (
content = M{
"Fields": []string{
"Router Pattern",
"Filter Function",
},
}
)
filterTypeData := BeeApp.reportFilter()
filterTypes := make([]string, 0, len(filterTypeData))
for k, _ := range filterTypeData {
filterTypes = append(filterTypes, k)
}
content["Data"] = filterTypeData
content["Methods"] = filterTypes
data["Content"] = content
data["Title"] = "Filters"
writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl)
default:
rw.Write([]byte("command not support"))
}
}
func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) {
tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl))
for _, tpl := range tpls {
tmpl = template.Must(tmpl.Parse(tpl))
}
tmpl.Execute(rw, data)
}
func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} {
response := make([]map[string]interface{}, len(*healthCheckResults))
for i, healthCheckResult := range *healthCheckResults {
currentResultMap := make(map[string]interface{})
currentResultMap["name"] = healthCheckResult[0]
currentResultMap["message"] = healthCheckResult[1]
currentResultMap["status"] = healthCheckResult[2]
response[i] = currentResultMap
}
return response
}
// PrintTree print all routers
// Deprecated using BeeApp directly
func PrintTree() M {
return BeeApp.PrintTree()
}

249
server/web/admin_test.go Normal file
View File

@ -0,0 +1,249 @@
package web
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/core/governor"
)
type SampleDatabaseCheck struct {
}
type SampleCacheCheck struct {
}
func (dc *SampleDatabaseCheck) Check() error {
return nil
}
func (cc *SampleCacheCheck) Check() error {
return errors.New("no cache detected")
}
func TestList_01(t *testing.T) {
m := make(M)
list("BConfig", BConfig, m)
t.Log(m)
om := oldMap()
for k, v := range om {
if fmt.Sprint(m[k]) != fmt.Sprint(v) {
t.Log(k, "old-key", v, "new-key", m[k])
t.FailNow()
}
}
}
func oldMap() M {
m := make(M)
m["BConfig.AppName"] = BConfig.AppName
m["BConfig.RunMode"] = BConfig.RunMode
m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive
m["BConfig.ServerName"] = BConfig.ServerName
m["BConfig.RecoverPanic"] = BConfig.RecoverPanic
m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody
m["BConfig.EnableGzip"] = BConfig.EnableGzip
m["BConfig.MaxMemory"] = BConfig.MaxMemory
m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow
m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful
m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut
m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4
m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP
m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr
m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort
m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS
m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr
m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort
m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile
m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile
m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin
m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr
m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort
m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi
m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo
m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender
m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs
m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName
m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator
m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex
m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir
m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip
m["BConfig.WebConfig.StaticCacheFileSize"] = BConfig.WebConfig.StaticCacheFileSize
m["BConfig.WebConfig.StaticCacheFileNum"] = BConfig.WebConfig.StaticCacheFileNum
m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft
m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight
m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath
m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF
m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire
m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn
m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider
m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName
m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime
m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig
m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime
m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie
m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain
m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly
m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs
m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs
m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat
m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum
m["BConfig.Log.Outputs"] = BConfig.Log.Outputs
return m
}
func TestWriteJSON(t *testing.T) {
t.Log("Testing the adding of JSON to the response")
w := httptest.NewRecorder()
originalBody := []int{1, 2, 3}
res, _ := json.Marshal(originalBody)
writeJSON(w, res)
decodedBody := []int{}
err := json.NewDecoder(w.Body).Decode(&decodedBody)
if err != nil {
t.Fatal("Could not decode response body into slice.")
}
for i := range decodedBody {
if decodedBody[i] != originalBody[i] {
t.Fatalf("Expected %d but got %d in decoded body slice", originalBody[i], decodedBody[i])
}
}
}
func TestHealthCheckHandlerDefault(t *testing.T) {
endpointPath := "/healthcheck"
governor.AddHealthCheck("database", &SampleDatabaseCheck{})
governor.AddHealthCheck("cache", &SampleCacheCheck{})
req, err := http.NewRequest("GET", endpointPath, nil)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
handler := http.HandlerFunc(heathCheck)
handler.ServeHTTP(w, req)
if status := w.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
if !strings.Contains(w.Body.String(), "database") {
t.Errorf("Expected 'database' in generated template.")
}
}
func TestBuildHealthCheckResponseList(t *testing.T) {
healthCheckResults := [][]string{
[]string{
"error",
"Database",
"Error occured whie starting the db",
},
[]string{
"success",
"Cache",
"Cache started successfully",
},
}
responseList := buildHealthCheckResponseList(&healthCheckResults)
if len(responseList) != len(healthCheckResults) {
t.Errorf("invalid response map length: got %d want %d",
len(responseList), len(healthCheckResults))
}
responseFields := []string{"name", "message", "status"}
for _, response := range responseList {
for _, field := range responseFields {
_, ok := response[field]
if !ok {
t.Errorf("expected %s to be in the response %v", field, response)
}
}
}
}
func TestHealthCheckHandlerReturnsJSON(t *testing.T) {
governor.AddHealthCheck("database", &SampleDatabaseCheck{})
governor.AddHealthCheck("cache", &SampleCacheCheck{})
req, err := http.NewRequest("GET", "/healthcheck?json=true", nil)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
handler := http.HandlerFunc(heathCheck)
handler.ServeHTTP(w, req)
if status := w.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
decodedResponseBody := []map[string]interface{}{}
expectedResponseBody := []map[string]interface{}{}
expectedJSONString := []byte(`
[
{
"message":"database",
"name":"success",
"status":"OK"
},
{
"message":"cache",
"name":"error",
"status":"no cache detected"
}
]
`)
json.Unmarshal(expectedJSONString, &expectedResponseBody)
json.Unmarshal(w.Body.Bytes(), &decodedResponseBody)
if len(expectedResponseBody) != len(decodedResponseBody) {
t.Errorf("invalid response map length: got %d want %d",
len(decodedResponseBody), len(expectedResponseBody))
}
assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody))
assert.Equal(t, 2, len(decodedResponseBody))
var database, cache map[string]interface{}
if decodedResponseBody[0]["message"] == "database" {
database = decodedResponseBody[0]
cache = decodedResponseBody[1]
} else {
database = decodedResponseBody[1]
cache = decodedResponseBody[0]
}
assert.Equal(t, expectedResponseBody[0], database)
assert.Equal(t, expectedResponseBody[1], cache)
}

356
server/web/adminui.go Normal file

File diff suppressed because one or more lines are too long

103
server/web/beego.go Normal file
View File

@ -0,0 +1,103 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"os"
"path/filepath"
"sync"
)
const (
// DEV is for develop
DEV = "dev"
// PROD is for production
PROD = "prod"
)
// M is Map shortcut
type M map[string]interface{}
// Hook function to run
type hookfunc func() error
var (
hooks = make([]hookfunc, 0) // hook function slice to store the hookfunc
)
// AddAPPStartHook is used to register the hookfunc
// The hookfuncs will run in beego.Run()
// such as initiating session , starting middleware , building template, starting admin control and so on.
func AddAPPStartHook(hf ...hookfunc) {
hooks = append(hooks, hf...)
}
// Run beego application.
// beego.Run() default run on HttpPort
// beego.Run("localhost")
// beego.Run(":8089")
// beego.Run("127.0.0.1:8089")
func Run(params ...string) {
if len(params) > 0 && params[0] != "" {
BeeApp.Run(params[0])
}
BeeApp.Run("")
}
// RunWithMiddleWares Run beego application with middlewares.
func RunWithMiddleWares(addr string, mws ...MiddleWare) {
BeeApp.Run(addr, mws...)
}
var initHttpOnce sync.Once
// TODO move to module init function
func initBeforeHTTPRun() {
initHttpOnce.Do(func() {
// init hooks
AddAPPStartHook(
registerMime,
registerDefaultErrorHandler,
registerSession,
registerTemplate,
registerAdmin,
registerGzip,
registerCommentRouter,
)
for _, hk := range hooks {
if err := hk(); err != nil {
panic(err)
}
}
})
}
// TestBeegoInit is for test package init
func TestBeegoInit(ap string) {
path := filepath.Join(ap, "conf", "app.conf")
os.Chdir(ap)
InitBeegoBeforeTest(path)
}
// InitBeegoBeforeTest is for test package init
func InitBeegoBeforeTest(appConfigPath string) {
if err := LoadAppConfig(appConfigProvider, appConfigPath); err != nil {
panic(err)
}
BConfig.RunMode = "test"
initBeforeHTTPRun()
}

View File

@ -0,0 +1,19 @@
Copyright (c) 2011-2014 Dmitry Chestnykh <dmitry@codingrobots.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -0,0 +1,45 @@
# Captcha
an example for use captcha
```
package controllers
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/cache"
"github.com/astaxie/beego/utils/captcha"
)
var cpt *captcha.Captcha
func init() {
// use beego cache system store the captcha data
store := cache.NewMemoryCache()
cpt = captcha.NewWithFilter("/captcha/", store)
}
type MainController struct {
beego.Controller
}
func (this *MainController) Get() {
this.TplName = "index.tpl"
}
func (this *MainController) Post() {
this.TplName = "index.tpl"
this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
}
```
template usage
```
{{.Success}}
<form action="/" method="post">
{{create_captcha}}
<input name="captcha" type="text">
</form>
```

View File

@ -0,0 +1,281 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package captcha implements generation and verification of image CAPTCHAs.
// an example for use captcha
//
// ```
// package controllers
//
// import (
// "github.com/astaxie/beego"
// "github.com/astaxie/beego/cache"
// "github.com/astaxie/beego/utils/captcha"
// )
//
// var cpt *captcha.Captcha
//
// func init() {
// // use beego cache system store the captcha data
// store := cache.NewMemoryCache()
// cpt = captcha.NewWithFilter("/captcha/", store)
// }
//
// type MainController struct {
// beego.Controller
// }
//
// func (this *MainController) Get() {
// this.TplName = "index.tpl"
// }
//
// func (this *MainController) Post() {
// this.TplName = "index.tpl"
//
// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
// }
// ```
//
// template usage
//
// ```
// {{.Success}}
// <form action="/" method="post">
// {{create_captcha}}
// <input name="captcha" type="text">
// </form>
// ```
package captcha
import (
context2 "context"
"fmt"
"html/template"
"net/http"
"path"
"strings"
"time"
"github.com/astaxie/beego/core/logs"
"github.com/astaxie/beego/core/utils"
"github.com/astaxie/beego/server/web"
"github.com/astaxie/beego/server/web/context"
)
var (
defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
)
const (
// default captcha attributes
challengeNums = 6
expiration = 600 * time.Second
fieldIDName = "captcha_id"
fieldCaptchaName = "captcha"
cachePrefix = "captcha_"
defaultURLPrefix = "/captcha/"
)
// Captcha struct
type Captcha struct {
// beego cache store
store Storage
// url prefix for captcha image
URLPrefix string
// specify captcha id input field name
FieldIDName string
// specify captcha result input field name
FieldCaptchaName string
// captcha image width and height
StdWidth int
StdHeight int
// captcha chars nums
ChallengeNums int
// captcha expiration seconds
Expiration time.Duration
// cache key prefix
CachePrefix string
}
// generate key string
func (c *Captcha) key(id string) string {
return c.CachePrefix + id
}
// generate rand chars with default chars
func (c *Captcha) genRandChars() []byte {
return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...)
}
// Handler beego filter handler for serve captcha image
func (c *Captcha) Handler(ctx *context.Context) {
var chars []byte
id := path.Base(ctx.Request.RequestURI)
if i := strings.Index(id, "."); i != -1 {
id = id[:i]
}
key := c.key(id)
if len(ctx.Input.Query("reload")) > 0 {
chars = c.genRandChars()
if err := c.store.Put(context2.Background(), key, chars, c.Expiration); err != nil {
ctx.Output.SetStatus(500)
ctx.WriteString("captcha reload error")
logs.Error("Reload Create Captcha Error:", err)
return
}
} else {
val, _ := c.store.Get(context2.Background(), key)
if v, ok := val.([]byte); ok {
chars = v
} else {
ctx.Output.SetStatus(404)
ctx.WriteString("captcha not found")
return
}
}
img := NewImage(chars, c.StdWidth, c.StdHeight)
if _, err := img.WriteTo(ctx.ResponseWriter); err != nil {
logs.Error("Write Captcha Image Error:", err)
}
}
// CreateCaptchaHTML template func for output html
func (c *Captcha) CreateCaptchaHTML() template.HTML {
value, err := c.CreateCaptcha()
if err != nil {
logs.Error("Create Captcha Error:", err)
return ""
}
// create html
return template.HTML(fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`+
`<a class="captcha" href="javascript:">`+
`<img onclick="this.src=('%s%s.png?reload='+(new Date()).getTime())" class="captcha-img" src="%s%s.png">`+
`</a>`, c.FieldIDName, value, c.URLPrefix, value, c.URLPrefix, value))
}
// CreateCaptcha create a new captcha id
func (c *Captcha) CreateCaptcha() (string, error) {
// generate captcha id
id := string(utils.RandomCreateBytes(15))
// get the captcha chars
chars := c.genRandChars()
// save to store
if err := c.store.Put(context2.Background(), c.key(id), chars, c.Expiration); err != nil {
return "", err
}
return id, nil
}
// VerifyReq verify from a request
func (c *Captcha) VerifyReq(req *http.Request) bool {
req.ParseForm()
return c.Verify(req.Form.Get(c.FieldIDName), req.Form.Get(c.FieldCaptchaName))
}
// Verify direct verify id and challenge string
func (c *Captcha) Verify(id string, challenge string) (success bool) {
if len(challenge) == 0 || len(id) == 0 {
return
}
var chars []byte
key := c.key(id)
val, _ := c.store.Get(context2.Background(), key)
if v, ok := val.([]byte); ok {
chars = v
} else {
return
}
defer func() {
// finally remove it
c.store.Delete(context2.Background(), key)
}()
if len(chars) != len(challenge) {
return
}
// verify challenge
for i, c := range chars {
if c != challenge[i]-48 {
return
}
}
return true
}
// NewCaptcha create a new captcha.Captcha
func NewCaptcha(urlPrefix string, store Storage) *Captcha {
cpt := &Captcha{}
cpt.store = store
cpt.FieldIDName = fieldIDName
cpt.FieldCaptchaName = fieldCaptchaName
cpt.ChallengeNums = challengeNums
cpt.Expiration = expiration
cpt.CachePrefix = cachePrefix
cpt.StdWidth = stdWidth
cpt.StdHeight = stdHeight
if len(urlPrefix) == 0 {
urlPrefix = defaultURLPrefix
}
if urlPrefix[len(urlPrefix)-1] != '/' {
urlPrefix += "/"
}
cpt.URLPrefix = urlPrefix
return cpt
}
// NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image
// and add a template func for output html
func NewWithFilter(urlPrefix string, store Storage) *Captcha {
cpt := NewCaptcha(urlPrefix, store)
// create filter for serve captcha image
web.InsertFilter(cpt.URLPrefix+"*", web.BeforeRouter, cpt.Handler)
// add to template func map
web.AddFuncMap("create_captcha", cpt.CreateCaptchaHTML)
return cpt
}
type Storage interface {
// Get a cached value by key.
Get(ctx context2.Context, key string) (interface{}, error)
// Set a cached value with key and expire time.
Put(ctx context2.Context, key string, val interface{}, timeout time.Duration) error
// Delete cached value by key.
Delete(ctx context2.Context, key string) error
}

501
server/web/captcha/image.go Normal file
View File

@ -0,0 +1,501 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package captcha
import (
"bytes"
"image"
"image/color"
"image/png"
"io"
"math"
)
const (
fontWidth = 11
fontHeight = 18
blackChar = 1
// Standard width and height of a captcha image.
stdWidth = 240
stdHeight = 80
// Maximum absolute skew factor of a single digit.
maxSkew = 0.7
// Number of background circles.
circleCount = 20
)
var font = [][]byte{
{ // 0
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 1
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
},
{ // 2
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
},
{ // 3
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 4
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0,
0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
},
{ // 5
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 6
0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0,
1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,
1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 7
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
},
{ // 8
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 9
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
},
}
// Image struct
type Image struct {
*image.Paletted
numWidth int
numHeight int
dotSize int
}
var prng = &siprng{}
// randIntn returns a pseudorandom non-negative int in range [0, n).
func randIntn(n int) int {
return prng.Intn(n)
}
// randInt returns a pseudorandom int in range [from, to].
func randInt(from, to int) int {
return prng.Intn(to+1-from) + from
}
// randFloat returns a pseudorandom float64 in range [from, to].
func randFloat(from, to float64) float64 {
return (to-from)*prng.Float64() + from
}
func randomPalette() color.Palette {
p := make([]color.Color, circleCount+1)
// Transparent color.
p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00}
// Primary color.
prim := color.RGBA{
uint8(randIntn(129)),
uint8(randIntn(129)),
uint8(randIntn(129)),
0xFF,
}
p[1] = prim
// Circle colors.
for i := 2; i <= circleCount; i++ {
p[i] = randomBrightness(prim, 255)
}
return p
}
// NewImage returns a new captcha image of the given width and height with the
// given digits, where each digit must be in range 0-9.
func NewImage(digits []byte, width, height int) *Image {
m := new(Image)
m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette())
m.calculateSizes(width, height, len(digits))
// Randomly position captcha inside the image.
maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize
maxy := height - m.numHeight - m.dotSize*2
var border int
if width > height {
border = height / 5
} else {
border = width / 5
}
x := randInt(border, maxx-border)
y := randInt(border, maxy-border)
// Draw digits.
for _, n := range digits {
m.drawDigit(font[n], x, y)
x += m.numWidth + m.dotSize
}
// Draw strike-through line.
m.strikeThrough()
// Apply wave distortion.
m.distort(randFloat(5, 10), randFloat(100, 200))
// Fill image with random circles.
m.fillWithCircles(circleCount, m.dotSize)
return m
}
// encodedPNG encodes an image to PNG and returns
// the result as a byte slice.
func (m *Image) encodedPNG() []byte {
var buf bytes.Buffer
if err := png.Encode(&buf, m.Paletted); err != nil {
panic(err.Error())
}
return buf.Bytes()
}
// WriteTo writes captcha image in PNG format into the given writer.
func (m *Image) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.encodedPNG())
return int64(n), err
}
func (m *Image) calculateSizes(width, height, ncount int) {
// Goal: fit all digits inside the image.
var border int
if width > height {
border = height / 4
} else {
border = width / 4
}
// Convert everything to floats for calculations.
w := float64(width - border*2)
h := float64(height - border*2)
// fw takes into account 1-dot spacing between digits.
fw := float64(fontWidth + 1)
fh := float64(fontHeight)
nc := float64(ncount)
// Calculate the width of a single digit taking into account only the
// width of the image.
nw := w / nc
// Calculate the height of a digit from this width.
nh := nw * fh / fw
// Digit too high?
if nh > h {
// Fit digits based on height.
nh = h
nw = fw / fh * nh
}
// Calculate dot size.
m.dotSize = int(nh / fh)
if m.dotSize < 1 {
m.dotSize = 1
}
// Save everything, making the actual width smaller by 1 dot to account
// for spacing between digits.
m.numWidth = int(nw) - m.dotSize
m.numHeight = int(nh)
}
func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) {
for x := fromX; x <= toX; x++ {
m.SetColorIndex(x, y, colorIdx)
}
}
func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) {
f := 1 - radius
dfx := 1
dfy := -2 * radius
xo := 0
yo := radius
m.SetColorIndex(x, y+radius, colorIdx)
m.SetColorIndex(x, y-radius, colorIdx)
m.drawHorizLine(x-radius, x+radius, y, colorIdx)
for xo < yo {
if f >= 0 {
yo--
dfy += 2
f += dfy
}
xo++
dfx += 2
f += dfx
m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx)
m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx)
m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx)
m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx)
}
}
func (m *Image) fillWithCircles(n, maxradius int) {
maxx := m.Bounds().Max.X
maxy := m.Bounds().Max.Y
for i := 0; i < n; i++ {
colorIdx := uint8(randInt(1, circleCount-1))
r := randInt(1, maxradius)
m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx)
}
}
func (m *Image) strikeThrough() {
maxx := m.Bounds().Max.X
maxy := m.Bounds().Max.Y
y := randInt(maxy/3, maxy-maxy/3)
amplitude := randFloat(5, 20)
period := randFloat(80, 180)
dx := 2.0 * math.Pi / period
for x := 0; x < maxx; x++ {
xo := amplitude * math.Cos(float64(y)*dx)
yo := amplitude * math.Sin(float64(x)*dx)
for yn := 0; yn < m.dotSize; yn++ {
r := randInt(0, m.dotSize)
m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1)
}
}
}
func (m *Image) drawDigit(digit []byte, x, y int) {
skf := randFloat(-maxSkew, maxSkew)
xs := float64(x)
r := m.dotSize / 2
y += randInt(-r, r)
for yo := 0; yo < fontHeight; yo++ {
for xo := 0; xo < fontWidth; xo++ {
if digit[yo*fontWidth+xo] != blackChar {
continue
}
m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1)
}
xs += skf
x = int(xs)
}
}
func (m *Image) distort(amplude float64, period float64) {
w := m.Bounds().Max.X
h := m.Bounds().Max.Y
oldm := m.Paletted
newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette)
dx := 2.0 * math.Pi / period
for x := 0; x < w; x++ {
for y := 0; y < h; y++ {
xo := amplude * math.Sin(float64(y)*dx)
yo := amplude * math.Cos(float64(x)*dx)
newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo)))
}
}
m.Paletted = newm
}
func randomBrightness(c color.RGBA, max uint8) color.RGBA {
minc := min3(c.R, c.G, c.B)
maxc := max3(c.R, c.G, c.B)
if maxc > max {
return c
}
n := randIntn(int(max-maxc)) - int(minc)
return color.RGBA{
uint8(int(c.R) + n),
uint8(int(c.G) + n),
uint8(int(c.B) + n),
c.A,
}
}
func min3(x, y, z uint8) (m uint8) {
m = x
if y < m {
m = y
}
if z < m {
m = z
}
return
}
func max3(x, y, z uint8) (m uint8) {
m = x
if y > m {
m = y
}
if z > m {
m = z
}
return
}

View File

@ -0,0 +1,52 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package captcha
import (
"testing"
"github.com/astaxie/beego/core/utils"
)
type byteCounter struct {
n int64
}
func (bc *byteCounter) Write(b []byte) (int, error) {
bc.n += int64(len(b))
return len(b), nil
}
func BenchmarkNewImage(b *testing.B) {
b.StopTimer()
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
b.StartTimer()
for i := 0; i < b.N; i++ {
NewImage(d, stdWidth, stdHeight)
}
}
func BenchmarkImageWriteTo(b *testing.B) {
b.StopTimer()
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
b.StartTimer()
counter := &byteCounter{}
for i := 0; i < b.N; i++ {
img := NewImage(d, stdWidth, stdHeight)
img.WriteTo(counter)
b.SetBytes(counter.n)
counter.n = 0
}
}

View File

@ -0,0 +1,277 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package captcha
import (
"crypto/rand"
"encoding/binary"
"io"
"sync"
)
// siprng is PRNG based on SipHash-2-4.
type siprng struct {
mu sync.Mutex
k0, k1, ctr uint64
}
// siphash implements SipHash-2-4, accepting a uint64 as a message.
func siphash(k0, k1, m uint64) uint64 {
// Initialization.
v0 := k0 ^ 0x736f6d6570736575
v1 := k1 ^ 0x646f72616e646f6d
v2 := k0 ^ 0x6c7967656e657261
v3 := k1 ^ 0x7465646279746573
t := uint64(8) << 56
// Compression.
v3 ^= m
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
v0 ^= m
// Compress last block.
v3 ^= t
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
v0 ^= t
// Finalization.
v2 ^= 0xff
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 3.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 4.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
return v0 ^ v1 ^ v2 ^ v3
}
// rekey sets a new PRNG key, which is read from crypto/rand.
func (p *siprng) rekey() {
var k [16]byte
if _, err := io.ReadFull(rand.Reader, k[:]); err != nil {
panic(err.Error())
}
p.k0 = binary.LittleEndian.Uint64(k[0:8])
p.k1 = binary.LittleEndian.Uint64(k[8:16])
p.ctr = 1
}
// Uint64 returns a new pseudorandom uint64.
// It rekeys PRNG on the first call and every 64 MB of generated data.
func (p *siprng) Uint64() uint64 {
p.mu.Lock()
if p.ctr == 0 || p.ctr > 8*1024*1024 {
p.rekey()
}
v := siphash(p.k0, p.k1, p.ctr)
p.ctr++
p.mu.Unlock()
return v
}
func (p *siprng) Int63() int64 {
return int64(p.Uint64() & 0x7fffffffffffffff)
}
func (p *siprng) Uint32() uint32 {
return uint32(p.Uint64())
}
func (p *siprng) Int31() int32 {
return int32(p.Uint32() & 0x7fffffff)
}
func (p *siprng) Intn(n int) int {
if n <= 0 {
panic("invalid argument to Intn")
}
if n <= 1<<31-1 {
return int(p.Int31n(int32(n)))
}
return int(p.Int63n(int64(n)))
}
func (p *siprng) Int63n(n int64) int64 {
if n <= 0 {
panic("invalid argument to Int63n")
}
max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
v := p.Int63()
for v > max {
v = p.Int63()
}
return v % n
}
func (p *siprng) Int31n(n int32) int32 {
if n <= 0 {
panic("invalid argument to Int31n")
}
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
v := p.Int31()
for v > max {
v = p.Int31()
}
return v % n
}
func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) }

View File

@ -0,0 +1,33 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package captcha
import "testing"
func TestSiphash(t *testing.T) {
good := uint64(0xe849e8bb6ffe2567)
cur := siphash(0, 0, 0)
if cur != good {
t.Fatalf("siphash: expected %x, got %x", good, cur)
}
}
func BenchmarkSiprng(b *testing.B) {
b.SetBytes(8)
p := &siprng{}
for i := 0; i < b.N; i++ {
p.Uint64()
}
}

537
server/web/config.go Normal file
View File

@ -0,0 +1,537 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
context2 "context"
"crypto/tls"
"fmt"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"github.com/astaxie/beego"
"github.com/astaxie/beego/core/config"
"github.com/astaxie/beego/core/logs"
"github.com/astaxie/beego/server/web/session"
"github.com/astaxie/beego/core/utils"
"github.com/astaxie/beego/server/web/context"
)
// Config is the main struct for BConfig
// TODO after supporting multiple servers, remove common config to somewhere else
type Config struct {
AppName string // Application name
RunMode string // Running Mode: dev | prod
RouterCaseSensitive bool
ServerName string
RecoverPanic bool
RecoverFunc func(*context.Context, *Config)
CopyRequestBody bool
EnableGzip bool
MaxMemory int64
EnableErrorsShow bool
EnableErrorsRender bool
Listen Listen
WebConfig WebConfig
Log LogConfig
}
// Listen holds for http and https related config
type Listen struct {
Graceful bool // Graceful means use graceful module to start the server
ServerTimeOut int64
ListenTCP4 bool
EnableHTTP bool
HTTPAddr string
HTTPPort int
AutoTLS bool
Domains []string
TLSCacheDir string
EnableHTTPS bool
EnableMutualHTTPS bool
HTTPSAddr string
HTTPSPort int
HTTPSCertFile string
HTTPSKeyFile string
TrustCaFile string
EnableAdmin bool
AdminAddr string
AdminPort int
EnableFcgi bool
EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O
ClientAuth int
}
// WebConfig holds web related config
type WebConfig struct {
AutoRender bool
EnableDocs bool
FlashName string
FlashSeparator string
DirectoryIndex bool
StaticDir map[string]string
StaticExtensionsToGzip []string
StaticCacheFileSize int
StaticCacheFileNum int
TemplateLeft string
TemplateRight string
ViewsPath string
CommentRouterPath string
EnableXSRF bool
XSRFKey string
XSRFExpire int
Session SessionConfig
}
// SessionConfig holds session related config
type SessionConfig struct {
SessionOn bool
SessionProvider string
SessionName string
SessionGCMaxLifetime int64
SessionProviderConfig string
SessionCookieLifeTime int
SessionAutoSetCookie bool
SessionDomain string
SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies.
SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers
SessionNameInHTTPHeader string
SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params
}
// LogConfig holds Log related config
type LogConfig struct {
AccessLogs bool
EnableStaticLogs bool // log static files requests default: false
AccessLogsFormat string // access log format: JSON_FORMAT, APACHE_FORMAT or empty string
FileLineNum bool
Outputs map[string]string // Store Adaptor : config
}
var (
// BConfig is the default config for Application
BConfig *Config
// AppConfig is the instance of Config, store the config information from file
AppConfig *beegoAppConfig
// AppPath is the absolute path to the app
AppPath string
// GlobalSessions is the instance for the session manager
GlobalSessions *session.Manager
// appConfigPath is the path to the config files
appConfigPath string
// appConfigProvider is the provider for the config, default is ini
appConfigProvider = "ini"
// WorkPath is the absolute path to project root directory
WorkPath string
)
func init() {
BConfig = newBConfig()
var err error
if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil {
panic(err)
}
WorkPath, err = os.Getwd()
if err != nil {
panic(err)
}
var filename = "app.conf"
if os.Getenv("BEEGO_RUNMODE") != "" {
filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf"
}
appConfigPath = filepath.Join(WorkPath, "conf", filename)
if !utils.FileExists(appConfigPath) {
appConfigPath = filepath.Join(AppPath, "conf", filename)
if !utils.FileExists(appConfigPath) {
AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()}
return
}
}
if err = parseConfig(appConfigPath); err != nil {
panic(err)
}
}
func defaultRecoverPanic(ctx *context.Context, cfg *Config) {
if err := recover(); err != nil {
if err == ErrAbort {
return
}
if !cfg.RecoverPanic {
panic(err)
}
if cfg.EnableErrorsShow {
if _, ok := ErrorMaps[fmt.Sprint(err)]; ok {
exception(fmt.Sprint(err), ctx)
return
}
}
var stack string
logs.Critical("the request url is ", ctx.Input.URL())
logs.Critical("Handler crashed with error", err)
for i := 1; ; i++ {
_, file, line, ok := runtime.Caller(i)
if !ok {
break
}
logs.Critical(fmt.Sprintf("%s:%d", file, line))
stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
}
if cfg.RunMode == DEV && cfg.EnableErrorsRender {
showErr(err, ctx, stack)
}
if ctx.Output.Status != 0 {
ctx.ResponseWriter.WriteHeader(ctx.Output.Status)
} else {
ctx.ResponseWriter.WriteHeader(500)
}
}
}
func newBConfig() *Config {
res := &Config{
AppName: "beego",
RunMode: PROD,
RouterCaseSensitive: true,
ServerName: "beegoServer:" + beego.VERSION,
RecoverPanic: true,
CopyRequestBody: false,
EnableGzip: false,
MaxMemory: 1 << 26, // 64MB
EnableErrorsShow: true,
EnableErrorsRender: true,
Listen: Listen{
Graceful: false,
ServerTimeOut: 0,
ListenTCP4: false,
EnableHTTP: true,
AutoTLS: false,
Domains: []string{},
TLSCacheDir: ".",
HTTPAddr: "",
HTTPPort: 8080,
EnableHTTPS: false,
HTTPSAddr: "",
HTTPSPort: 10443,
HTTPSCertFile: "",
HTTPSKeyFile: "",
EnableAdmin: false,
AdminAddr: "",
AdminPort: 8088,
EnableFcgi: false,
EnableStdIo: false,
ClientAuth: int(tls.RequireAndVerifyClientCert),
},
WebConfig: WebConfig{
AutoRender: true,
EnableDocs: false,
FlashName: "BEEGO_FLASH",
FlashSeparator: "BEEGOFLASH",
DirectoryIndex: false,
StaticDir: map[string]string{"/static": "static"},
StaticExtensionsToGzip: []string{".css", ".js"},
StaticCacheFileSize: 1024 * 100,
StaticCacheFileNum: 1000,
TemplateLeft: "{{",
TemplateRight: "}}",
ViewsPath: "views",
CommentRouterPath: "controllers",
EnableXSRF: false,
XSRFKey: "beegoxsrf",
XSRFExpire: 0,
Session: SessionConfig{
SessionOn: false,
SessionProvider: "memory",
SessionName: "beegosessionID",
SessionGCMaxLifetime: 3600,
SessionProviderConfig: "",
SessionDisableHTTPOnly: false,
SessionCookieLifeTime: 0, // set cookie default is the browser life
SessionAutoSetCookie: true,
SessionDomain: "",
SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers
SessionNameInHTTPHeader: "Beegosessionid",
SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params
},
},
Log: LogConfig{
AccessLogs: false,
EnableStaticLogs: false,
AccessLogsFormat: "APACHE_FORMAT",
FileLineNum: true,
Outputs: map[string]string{"console": ""},
},
}
res.RecoverFunc = defaultRecoverPanic
return res
}
// now only support ini, next will support json.
func parseConfig(appConfigPath string) (err error) {
AppConfig, err = newAppConfig(appConfigProvider, appConfigPath)
if err != nil {
return err
}
return assignConfig(AppConfig)
}
func assignConfig(ac config.Configer) error {
for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} {
assignSingleConfig(i, ac)
}
// set the run mode first
if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
BConfig.RunMode = envRunMode
} else if runMode, err := ac.String(nil, "RunMode"); runMode != "" && err == nil {
BConfig.RunMode = runMode
}
if sd, err := ac.String(nil, "StaticDir"); sd != "" && err == nil {
BConfig.WebConfig.StaticDir = map[string]string{}
sds := strings.Fields(sd)
for _, v := range sds {
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[1]
} else {
BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[0]
}
}
}
if sgz, err := ac.String(nil, "StaticExtensionsToGzip"); sgz != "" && err == nil {
extensions := strings.Split(sgz, ",")
fileExts := []string{}
for _, ext := range extensions {
ext = strings.TrimSpace(ext)
if ext == "" {
continue
}
if !strings.HasPrefix(ext, ".") {
ext = "." + ext
}
fileExts = append(fileExts, ext)
}
if len(fileExts) > 0 {
BConfig.WebConfig.StaticExtensionsToGzip = fileExts
}
}
if sfs, err := ac.Int(nil, "StaticCacheFileSize"); err == nil {
BConfig.WebConfig.StaticCacheFileSize = sfs
}
if sfn, err := ac.Int(nil, "StaticCacheFileNum"); err == nil {
BConfig.WebConfig.StaticCacheFileNum = sfn
}
if lo, err := ac.String(nil, "LogOutputs"); lo != "" && err == nil {
// if lo is not nil or empty
// means user has set his own LogOutputs
// clear the default setting to BConfig.Log.Outputs
BConfig.Log.Outputs = make(map[string]string)
los := strings.Split(lo, ";")
for _, v := range los {
if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 {
BConfig.Log.Outputs[logType2Config[0]] = logType2Config[1]
} else {
continue
}
}
}
// init log
logs.Reset()
for adaptor, config := range BConfig.Log.Outputs {
err := logs.SetLogger(adaptor, config)
if err != nil {
fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error()))
}
}
logs.SetLogFuncCall(BConfig.Log.FileLineNum)
return nil
}
func assignSingleConfig(p interface{}, ac config.Configer) {
pt := reflect.TypeOf(p)
if pt.Kind() != reflect.Ptr {
return
}
pt = pt.Elem()
if pt.Kind() != reflect.Struct {
return
}
pv := reflect.ValueOf(p).Elem()
for i := 0; i < pt.NumField(); i++ {
pf := pv.Field(i)
if !pf.CanSet() {
continue
}
name := pt.Field(i).Name
switch pf.Kind() {
case reflect.String:
pf.SetString(ac.DefaultString(nil, name, pf.String()))
case reflect.Int, reflect.Int64:
pf.SetInt(ac.DefaultInt64(nil, name, pf.Int()))
case reflect.Bool:
pf.SetBool(ac.DefaultBool(nil, name, pf.Bool()))
case reflect.Struct:
default:
// do nothing here
}
}
}
// LoadAppConfig allow developer to apply a config file
func LoadAppConfig(adapterName, configPath string) error {
absConfigPath, err := filepath.Abs(configPath)
if err != nil {
return err
}
if !utils.FileExists(absConfigPath) {
return fmt.Errorf("the target config file: %s don't exist", configPath)
}
appConfigPath = absConfigPath
appConfigProvider = adapterName
return parseConfig(appConfigPath)
}
type beegoAppConfig struct {
config.BaseConfiger
innerConfig config.Configer
}
func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, error) {
ac, err := config.NewConfig(appConfigProvider, appConfigPath)
if err != nil {
return nil, err
}
return &beegoAppConfig{innerConfig: ac}, nil
}
func (b *beegoAppConfig) Set(ctx context2.Context, key, val string) error {
if err := b.innerConfig.Set(nil, BConfig.RunMode+"::"+key, val); err != nil {
return b.innerConfig.Set(nil, key, val)
}
return nil
}
func (b *beegoAppConfig) String(ctx context2.Context, key string) (string, error) {
if v, err := b.innerConfig.String(nil, BConfig.RunMode+"::"+key); v != "" && err == nil {
return v, nil
}
return b.innerConfig.String(nil, key)
}
func (b *beegoAppConfig) Strings(ctx context2.Context, key string) ([]string, error) {
if v, err := b.innerConfig.Strings(nil, BConfig.RunMode+"::"+key); len(v) > 0 && err == nil {
return v, nil
}
return b.innerConfig.Strings(nil, key)
}
func (b *beegoAppConfig) Int(ctx context2.Context, key string) (int, error) {
if v, err := b.innerConfig.Int(nil, BConfig.RunMode+"::"+key); err == nil {
return v, nil
}
return b.innerConfig.Int(nil, key)
}
func (b *beegoAppConfig) Int64(ctx context2.Context, key string) (int64, error) {
if v, err := b.innerConfig.Int64(nil, BConfig.RunMode+"::"+key); err == nil {
return v, nil
}
return b.innerConfig.Int64(nil, key)
}
func (b *beegoAppConfig) Bool(ctx context2.Context, key string) (bool, error) {
if v, err := b.innerConfig.Bool(nil, BConfig.RunMode+"::"+key); err == nil {
return v, nil
}
return b.innerConfig.Bool(nil, key)
}
func (b *beegoAppConfig) Float(ctx context2.Context, key string) (float64, error) {
if v, err := b.innerConfig.Float(nil, BConfig.RunMode+"::"+key); err == nil {
return v, nil
}
return b.innerConfig.Float(nil, key)
}
func (b *beegoAppConfig) DefaultString(ctx context2.Context, key string, defaultVal string) string {
if v, err := b.String(nil, key); v != "" && err == nil {
return v
}
return defaultVal
}
func (b *beegoAppConfig) DefaultStrings(ctx context2.Context, key string, defaultVal []string) []string {
if v, err := b.Strings(ctx, key); len(v) != 0 && err == nil {
return v
}
return defaultVal
}
func (b *beegoAppConfig) DefaultInt(ctx context2.Context, key string, defaultVal int) int {
if v, err := b.Int(ctx, key); err == nil {
return v
}
return defaultVal
}
func (b *beegoAppConfig) DefaultInt64(ctx context2.Context, key string, defaultVal int64) int64 {
if v, err := b.Int64(ctx, key); err == nil {
return v
}
return defaultVal
}
func (b *beegoAppConfig) DefaultBool(ctx context2.Context, key string, defaultVal bool) bool {
if v, err := b.Bool(ctx, key); err == nil {
return v
}
return defaultVal
}
func (b *beegoAppConfig) DefaultFloat(ctx context2.Context, key string, defaultVal float64) float64 {
if v, err := b.Float(ctx, key); err == nil {
return v
}
return defaultVal
}
func (b *beegoAppConfig) DIY(ctx context2.Context, key string) (interface{}, error) {
return b.innerConfig.DIY(nil, key)
}
func (b *beegoAppConfig) GetSection(ctx context2.Context, section string) (map[string]string, error) {
return b.innerConfig.GetSection(nil, section)
}
func (b *beegoAppConfig) SaveConfigFile(ctx context2.Context, filename string) error {
return b.innerConfig.SaveConfigFile(nil, filename)
}

146
server/web/config_test.go Normal file
View File

@ -0,0 +1,146 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"encoding/json"
"reflect"
"testing"
beeJson "github.com/astaxie/beego/core/config/json"
)
func TestDefaults(t *testing.T) {
if BConfig.WebConfig.FlashName != "BEEGO_FLASH" {
t.Errorf("FlashName was not set to default.")
}
if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" {
t.Errorf("FlashName was not set to default.")
}
}
func TestAssignConfig_01(t *testing.T) {
_BConfig := &Config{}
_BConfig.AppName = "beego_test"
jcf := &beeJson.JSONConfig{}
ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`))
assignSingleConfig(_BConfig, ac)
if _BConfig.AppName != "beego_json" {
t.Log(_BConfig)
t.FailNow()
}
}
func TestAssignConfig_02(t *testing.T) {
_BConfig := &Config{}
bs, _ := json.Marshal(newBConfig())
jsonMap := M{}
json.Unmarshal(bs, &jsonMap)
configMap := M{}
for k, v := range jsonMap {
if reflect.TypeOf(v).Kind() == reflect.Map {
for k1, v1 := range v.(M) {
if reflect.TypeOf(v1).Kind() == reflect.Map {
for k2, v2 := range v1.(M) {
configMap[k2] = v2
}
} else {
configMap[k1] = v1
}
}
} else {
configMap[k] = v
}
}
configMap["MaxMemory"] = 1024
configMap["Graceful"] = true
configMap["XSRFExpire"] = 32
configMap["SessionProviderConfig"] = "file"
configMap["FileLineNum"] = true
jcf := &beeJson.JSONConfig{}
bs, _ = json.Marshal(configMap)
ac, _ := jcf.ParseData(bs)
for _, i := range []interface{}{_BConfig, &_BConfig.Listen, &_BConfig.WebConfig, &_BConfig.Log, &_BConfig.WebConfig.Session} {
assignSingleConfig(i, ac)
}
if _BConfig.MaxMemory != 1024 {
t.Log(_BConfig.MaxMemory)
t.FailNow()
}
if !_BConfig.Listen.Graceful {
t.Log(_BConfig.Listen.Graceful)
t.FailNow()
}
if _BConfig.WebConfig.XSRFExpire != 32 {
t.Log(_BConfig.WebConfig.XSRFExpire)
t.FailNow()
}
if _BConfig.WebConfig.Session.SessionProviderConfig != "file" {
t.Log(_BConfig.WebConfig.Session.SessionProviderConfig)
t.FailNow()
}
if !_BConfig.Log.FileLineNum {
t.Log(_BConfig.Log.FileLineNum)
t.FailNow()
}
}
func TestAssignConfig_03(t *testing.T) {
jcf := &beeJson.JSONConfig{}
ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`))
ac.Set(nil, "AppName", "test_app")
ac.Set(nil, "RunMode", "online")
ac.Set(nil, "StaticDir", "download:down download2:down2")
ac.Set(nil, "StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png")
ac.Set(nil, "StaticCacheFileSize", "87456")
ac.Set(nil, "StaticCacheFileNum", "1254")
assignConfig(ac)
t.Logf("%#v", BConfig)
if BConfig.AppName != "test_app" {
t.FailNow()
}
if BConfig.RunMode != "online" {
t.FailNow()
}
if BConfig.WebConfig.StaticDir["/download"] != "down" {
t.FailNow()
}
if BConfig.WebConfig.StaticDir["/download2"] != "down2" {
t.FailNow()
}
if BConfig.WebConfig.StaticCacheFileSize != 87456 {
t.FailNow()
}
if BConfig.WebConfig.StaticCacheFileNum != 1254 {
t.FailNow()
}
if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 {
t.FailNow()
}
}

View File

@ -0,0 +1,232 @@
// Copyright 2015 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"bytes"
"compress/flate"
"compress/gzip"
"compress/zlib"
"io"
"net/http"
"os"
"strconv"
"strings"
"sync"
)
var (
// Default size==20B same as nginx
defaultGzipMinLength = 20
// Content will only be compressed if content length is either unknown or greater than gzipMinLength.
gzipMinLength = defaultGzipMinLength
// Compression level used for deflate compression. (0-9).
gzipCompressLevel int
// List of HTTP methods to compress. If not set, only GET requests are compressed.
includedMethods map[string]bool
getMethodOnly bool
)
// InitGzip initializes the gzipcompress
func InitGzip(minLength, compressLevel int, methods []string) {
if minLength >= 0 {
gzipMinLength = minLength
}
gzipCompressLevel = compressLevel
if gzipCompressLevel < flate.NoCompression || gzipCompressLevel > flate.BestCompression {
gzipCompressLevel = flate.BestSpeed
}
getMethodOnly = (len(methods) == 0) || (len(methods) == 1 && strings.ToUpper(methods[0]) == "GET")
includedMethods = make(map[string]bool, len(methods))
for _, v := range methods {
includedMethods[strings.ToUpper(v)] = true
}
}
type resetWriter interface {
io.Writer
Reset(w io.Writer)
}
type nopResetWriter struct {
io.Writer
}
func (n nopResetWriter) Reset(w io.Writer) {
//do nothing
}
type acceptEncoder struct {
name string
levelEncode func(int) resetWriter
customCompressLevelPool *sync.Pool
bestCompressionPool *sync.Pool
}
func (ac acceptEncoder) encode(wr io.Writer, level int) resetWriter {
if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil {
return nopResetWriter{wr}
}
var rwr resetWriter
switch level {
case flate.BestSpeed:
rwr = ac.customCompressLevelPool.Get().(resetWriter)
case flate.BestCompression:
rwr = ac.bestCompressionPool.Get().(resetWriter)
default:
rwr = ac.levelEncode(level)
}
rwr.Reset(wr)
return rwr
}
func (ac acceptEncoder) put(wr resetWriter, level int) {
if ac.customCompressLevelPool == nil || ac.bestCompressionPool == nil {
return
}
wr.Reset(nil)
// notice
// compressionLevel==BestCompression DOES NOT MATTER
// sync.Pool will not memory leak
switch level {
case gzipCompressLevel:
ac.customCompressLevelPool.Put(wr)
case flate.BestCompression:
ac.bestCompressionPool.Put(wr)
}
}
var (
noneCompressEncoder = acceptEncoder{"", nil, nil, nil}
gzipCompressEncoder = acceptEncoder{
name: "gzip",
levelEncode: func(level int) resetWriter { wr, _ := gzip.NewWriterLevel(nil, level); return wr },
customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, gzipCompressLevel); return wr }},
bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }},
}
// According to: http://tools.ietf.org/html/rfc2616#section-3.5 the deflate compress in http is zlib indeed
// deflate
// The "zlib" format defined in RFC 1950 [31] in combination with
// the "deflate" compression mechanism described in RFC 1951 [29].
deflateCompressEncoder = acceptEncoder{
name: "deflate",
levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr },
customCompressLevelPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, gzipCompressLevel); return wr }},
bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := zlib.NewWriterLevel(nil, flate.BestCompression); return wr }},
}
)
var (
encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore
"gzip": gzipCompressEncoder,
"deflate": deflateCompressEncoder,
"*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip
"identity": noneCompressEncoder, // identity means none-compress
}
)
// WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate)
func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) {
return writeLevel(encoding, writer, file, flate.BestCompression)
}
// WriteBody reads writes content to writer by the specific encoding(gzip/deflate)
func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) {
if encoding == "" || len(content) < gzipMinLength {
_, err := writer.Write(content)
return false, "", err
}
return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel)
}
// writeLevel reads from reader and writes to writer by specific encoding and compress level.
// The compress level is defined by deflate package
func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) {
var outputWriter resetWriter
var err error
var ce = noneCompressEncoder
if cf, ok := encoderMap[encoding]; ok {
ce = cf
}
encoding = ce.name
outputWriter = ce.encode(writer, level)
defer ce.put(outputWriter, level)
_, err = io.Copy(outputWriter, reader)
if err != nil {
return false, "", err
}
switch outputWriter.(type) {
case io.WriteCloser:
outputWriter.(io.WriteCloser).Close()
}
return encoding != "", encoding, nil
}
// ParseEncoding will extract the right encoding for response
// the Accept-Encoding's sec is here:
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.3
func ParseEncoding(r *http.Request) string {
if r == nil {
return ""
}
if (getMethodOnly && r.Method == "GET") || includedMethods[r.Method] {
return parseEncoding(r)
}
return ""
}
type q struct {
name string
value float64
}
func parseEncoding(r *http.Request) string {
acceptEncoding := r.Header.Get("Accept-Encoding")
if acceptEncoding == "" {
return ""
}
var lastQ q
for _, v := range strings.Split(acceptEncoding, ",") {
v = strings.TrimSpace(v)
if v == "" {
continue
}
vs := strings.Split(v, ";")
var cf acceptEncoder
var ok bool
if cf, ok = encoderMap[vs[0]]; !ok {
continue
}
if len(vs) == 1 {
return cf.name
}
if len(vs) == 2 {
f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64)
if f == 0 {
continue
}
if f > lastQ.value {
lastQ = q{cf.name, f}
}
}
}
return lastQ.name
}

View File

@ -0,0 +1,59 @@
// Copyright 2015 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"net/http"
"testing"
)
func Test_ExtractEncoding(t *testing.T) {
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,deflate"}}}) != "gzip" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate,gzip"}}}) != "deflate" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate"}}}) != "deflate" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=.5,deflate;q=0.3"}}}) != "gzip" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0,deflate"}}}) != "deflate" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"deflate;q=0.5,gzip;q=0.5,identity"}}}) != "" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"*"}}}) != "gzip" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x,gzip,deflate"}}}) != "gzip" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,x,deflate"}}}) != "gzip" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x,deflate"}}}) != "deflate" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x"}}}) != "" {
t.Fail()
}
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x;q=0.8"}}}) != "gzip" {
t.Fail()
}
}

View File

@ -0,0 +1,262 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package context provide the context utils
// Usage:
//
// import "github.com/astaxie/beego/context"
//
// ctx := context.Context{Request:req,ResponseWriter:rw}
//
// more docs http://beego.me/docs/module/context.md
package context
import (
"bufio"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/astaxie/beego/core/utils"
)
// Commonly used mime-types
const (
ApplicationJSON = "application/json"
ApplicationXML = "application/xml"
ApplicationYAML = "application/x-yaml"
TextXML = "text/xml"
)
// NewContext return the Context with Input and Output
func NewContext() *Context {
return &Context{
Input: NewInput(),
Output: NewOutput(),
}
}
// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter.
// BeegoInput and BeegoOutput provides an api to operate request and response more easily.
type Context struct {
Input *BeegoInput
Output *BeegoOutput
Request *http.Request
ResponseWriter *Response
_xsrfToken string
}
// Reset initializes Context, BeegoInput and BeegoOutput
func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
ctx.Request = r
if ctx.ResponseWriter == nil {
ctx.ResponseWriter = &Response{}
}
ctx.ResponseWriter.reset(rw)
ctx.Input.Reset(ctx)
ctx.Output.Reset(ctx)
ctx._xsrfToken = ""
}
// Redirect redirects to localurl with http header status code.
func (ctx *Context) Redirect(status int, localurl string) {
http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status)
}
// Abort stops the request.
// If beego.ErrorMaps exists, panic body.
func (ctx *Context) Abort(status int, body string) {
ctx.Output.SetStatus(status)
panic(body)
}
// WriteString writes a string to response body.
func (ctx *Context) WriteString(content string) {
ctx.ResponseWriter.Write([]byte(content))
}
// GetCookie gets a cookie from a request for a given key.
// (Alias of BeegoInput.Cookie)
func (ctx *Context) GetCookie(key string) string {
return ctx.Input.Cookie(key)
}
// SetCookie sets a cookie for a response.
// (Alias of BeegoOutput.Cookie)
func (ctx *Context) SetCookie(name string, value string, others ...interface{}) {
ctx.Output.Cookie(name, value, others...)
}
// GetSecureCookie gets a secure cookie from a request for a given key.
func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) {
val := ctx.Input.Cookie(key)
if val == "" {
return "", false
}
parts := strings.SplitN(val, "|", 3)
if len(parts) != 3 {
return "", false
}
vs := parts[0]
timestamp := parts[1]
sig := parts[2]
h := hmac.New(sha256.New, []byte(Secret))
fmt.Fprintf(h, "%s%s", vs, timestamp)
if fmt.Sprintf("%02x", h.Sum(nil)) != sig {
return "", false
}
res, _ := base64.URLEncoding.DecodeString(vs)
return string(res), true
}
// SetSecureCookie sets a secure cookie for a response.
func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) {
vs := base64.URLEncoding.EncodeToString([]byte(value))
timestamp := strconv.FormatInt(time.Now().UnixNano(), 10)
h := hmac.New(sha256.New, []byte(Secret))
fmt.Fprintf(h, "%s%s", vs, timestamp)
sig := fmt.Sprintf("%02x", h.Sum(nil))
cookie := strings.Join([]string{vs, timestamp, sig}, "|")
ctx.Output.Cookie(name, cookie, others...)
}
// XSRFToken creates and returns an xsrf token string
func (ctx *Context) XSRFToken(key string, expire int64) string {
if ctx._xsrfToken == "" {
token, ok := ctx.GetSecureCookie(key, "_xsrf")
if !ok {
token = string(utils.RandomCreateBytes(32))
ctx.SetSecureCookie(key, "_xsrf", token, expire, "", "", true, true)
}
ctx._xsrfToken = token
}
return ctx._xsrfToken
}
// CheckXSRFCookie checks if the XSRF token in this request is valid or not.
// The token can be provided in the request header in the form "X-Xsrftoken" or "X-CsrfToken"
// or in form field value named as "_xsrf".
func (ctx *Context) CheckXSRFCookie() bool {
token := ctx.Input.Query("_xsrf")
if token == "" {
token = ctx.Request.Header.Get("X-Xsrftoken")
}
if token == "" {
token = ctx.Request.Header.Get("X-Csrftoken")
}
if token == "" {
ctx.Abort(422, "422")
return false
}
if ctx._xsrfToken != token {
ctx.Abort(417, "417")
return false
}
return true
}
// RenderMethodResult renders the return value of a controller method to the output
func (ctx *Context) RenderMethodResult(result interface{}) {
if result != nil {
renderer, ok := result.(Renderer)
if !ok {
err, ok := result.(error)
if ok {
renderer = errorRenderer(err)
} else {
renderer = jsonRenderer(result)
}
}
renderer.Render(ctx)
}
}
// Response is a wrapper for the http.ResponseWriter
// Started: if true, response was already written to so the other handler will not be executed
type Response struct {
http.ResponseWriter
Started bool
Status int
Elapsed time.Duration
}
func (r *Response) reset(rw http.ResponseWriter) {
r.ResponseWriter = rw
r.Status = 0
r.Started = false
}
// Write writes the data to the connection as part of a HTTP reply,
// and sets `Started` to true.
// Started: if true, the response was already sent
func (r *Response) Write(p []byte) (int, error) {
r.Started = true
return r.ResponseWriter.Write(p)
}
// WriteHeader sends a HTTP response header with status code,
// and sets `Started` to true.
func (r *Response) WriteHeader(code int) {
if r.Status > 0 {
//prevent multiple response.WriteHeader calls
return
}
r.Status = code
r.Started = true
r.ResponseWriter.WriteHeader(code)
}
// Hijack hijacker for http
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj, ok := r.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("webserver doesn't support hijacking")
}
return hj.Hijack()
}
// Flush http.Flusher
func (r *Response) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// CloseNotify http.CloseNotifier
func (r *Response) CloseNotify() <-chan bool {
if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
return nil
}
// Pusher http.Pusher
func (r *Response) Pusher() (pusher http.Pusher) {
if pusher, ok := r.ResponseWriter.(http.Pusher); ok {
return pusher
}
return nil
}

View File

@ -0,0 +1,47 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestXsrfReset_01(t *testing.T) {
r := &http.Request{}
c := NewContext()
c.Request = r
c.ResponseWriter = &Response{}
c.ResponseWriter.reset(httptest.NewRecorder())
c.Output.Reset(c)
c.Input.Reset(c)
c.XSRFToken("key", 16)
if c._xsrfToken == "" {
t.FailNow()
}
token := c._xsrfToken
c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r)
if c._xsrfToken != "" {
t.FailNow()
}
c.XSRFToken("key", 16)
if c._xsrfToken == "" {
t.FailNow()
}
if token == c._xsrfToken {
t.FailNow()
}
}

687
server/web/context/input.go Normal file
View File

@ -0,0 +1,687 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"bytes"
"compress/gzip"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
"github.com/astaxie/beego/server/web/session"
)
// Regexes for checking the accept headers
// TODO make sure these are correct
var (
acceptsHTMLRegex = regexp.MustCompile(`(text/html|application/xhtml\+xml)(?:,|$)`)
acceptsXMLRegex = regexp.MustCompile(`(application/xml|text/xml)(?:,|$)`)
acceptsJSONRegex = regexp.MustCompile(`(application/json)(?:,|$)`)
acceptsYAMLRegex = regexp.MustCompile(`(application/x-yaml)(?:,|$)`)
maxParam = 50
)
// BeegoInput operates the http request header, data, cookie and body.
// Contains router params and current session.
type BeegoInput struct {
Context *Context
CruSession session.Store
pnames []string
pvalues []string
data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
dataLock sync.RWMutex
RequestBody []byte
RunMethod string
RunController reflect.Type
}
// NewInput returns the BeegoInput generated by context.
func NewInput() *BeegoInput {
return &BeegoInput{
pnames: make([]string, 0, maxParam),
pvalues: make([]string, 0, maxParam),
data: make(map[interface{}]interface{}),
}
}
// Reset initializes the BeegoInput
func (input *BeegoInput) Reset(ctx *Context) {
input.Context = ctx
input.CruSession = nil
input.pnames = input.pnames[:0]
input.pvalues = input.pvalues[:0]
input.dataLock.Lock()
input.data = nil
input.dataLock.Unlock()
input.RequestBody = []byte{}
}
// Protocol returns the request protocol name, such as HTTP/1.1 .
func (input *BeegoInput) Protocol() string {
return input.Context.Request.Proto
}
// URI returns the full request url with query, string and fragment.
func (input *BeegoInput) URI() string {
return input.Context.Request.RequestURI
}
// URL returns the request url path (without query, string and fragment).
func (input *BeegoInput) URL() string {
return input.Context.Request.URL.Path
}
// Site returns the base site url as scheme://domain type.
func (input *BeegoInput) Site() string {
return input.Scheme() + "://" + input.Domain()
}
// Scheme returns the request scheme as "http" or "https".
func (input *BeegoInput) Scheme() string {
if scheme := input.Header("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if input.Context.Request.URL.Scheme != "" {
return input.Context.Request.URL.Scheme
}
if input.Context.Request.TLS == nil {
return "http"
}
return "https"
}
// Domain returns the host name (alias of host method)
func (input *BeegoInput) Domain() string {
return input.Host()
}
// Host returns the host name.
// If no host info in request, return localhost.
func (input *BeegoInput) Host() string {
if input.Context.Request.Host != "" {
if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil {
return hostPart
}
return input.Context.Request.Host
}
return "localhost"
}
// Method returns http request method.
func (input *BeegoInput) Method() string {
return input.Context.Request.Method
}
// Is returns the boolean value of this request is on given method, such as Is("POST").
func (input *BeegoInput) Is(method string) bool {
return input.Method() == method
}
// IsGet Is this a GET method request?
func (input *BeegoInput) IsGet() bool {
return input.Is("GET")
}
// IsPost Is this a POST method request?
func (input *BeegoInput) IsPost() bool {
return input.Is("POST")
}
// IsHead Is this a Head method request?
func (input *BeegoInput) IsHead() bool {
return input.Is("HEAD")
}
// IsOptions Is this a OPTIONS method request?
func (input *BeegoInput) IsOptions() bool {
return input.Is("OPTIONS")
}
// IsPut Is this a PUT method request?
func (input *BeegoInput) IsPut() bool {
return input.Is("PUT")
}
// IsDelete Is this a DELETE method request?
func (input *BeegoInput) IsDelete() bool {
return input.Is("DELETE")
}
// IsPatch Is this a PATCH method request?
func (input *BeegoInput) IsPatch() bool {
return input.Is("PATCH")
}
// IsAjax returns boolean of is this request generated by ajax.
func (input *BeegoInput) IsAjax() bool {
return input.Header("X-Requested-With") == "XMLHttpRequest"
}
// IsSecure returns boolean of this request is in https.
func (input *BeegoInput) IsSecure() bool {
return input.Scheme() == "https"
}
// IsWebsocket returns boolean of this request is in webSocket.
func (input *BeegoInput) IsWebsocket() bool {
return input.Header("Upgrade") == "websocket"
}
// IsUpload returns boolean of whether file uploads in this request or not..
func (input *BeegoInput) IsUpload() bool {
return strings.Contains(input.Header("Content-Type"), "multipart/form-data")
}
// AcceptsHTML Checks if request accepts html response
func (input *BeegoInput) AcceptsHTML() bool {
return acceptsHTMLRegex.MatchString(input.Header("Accept"))
}
// AcceptsXML Checks if request accepts xml response
func (input *BeegoInput) AcceptsXML() bool {
return acceptsXMLRegex.MatchString(input.Header("Accept"))
}
// AcceptsJSON Checks if request accepts json response
func (input *BeegoInput) AcceptsJSON() bool {
return acceptsJSONRegex.MatchString(input.Header("Accept"))
}
// AcceptsYAML Checks if request accepts json response
func (input *BeegoInput) AcceptsYAML() bool {
return acceptsYAMLRegex.MatchString(input.Header("Accept"))
}
// IP returns request client ip.
// if in proxy, return first proxy id.
// if error, return RemoteAddr.
func (input *BeegoInput) IP() string {
ips := input.Proxy()
if len(ips) > 0 && ips[0] != "" {
rip, _, err := net.SplitHostPort(ips[0])
if err != nil {
rip = ips[0]
}
return rip
}
if ip, _, err := net.SplitHostPort(input.Context.Request.RemoteAddr); err == nil {
return ip
}
return input.Context.Request.RemoteAddr
}
// Proxy returns proxy client ips slice.
func (input *BeegoInput) Proxy() []string {
if ips := input.Header("X-Forwarded-For"); ips != "" {
return strings.Split(ips, ",")
}
return []string{}
}
// Referer returns http referer header.
func (input *BeegoInput) Referer() string {
return input.Header("Referer")
}
// Refer returns http referer header.
func (input *BeegoInput) Refer() string {
return input.Referer()
}
// SubDomains returns sub domain string.
// if aa.bb.domain.com, returns aa.bb
func (input *BeegoInput) SubDomains() string {
parts := strings.Split(input.Host(), ".")
if len(parts) >= 3 {
return strings.Join(parts[:len(parts)-2], ".")
}
return ""
}
// Port returns request client port.
// when error or empty, return 80.
func (input *BeegoInput) Port() int {
if _, portPart, err := net.SplitHostPort(input.Context.Request.Host); err == nil {
port, _ := strconv.Atoi(portPart)
return port
}
return 80
}
// UserAgent returns request client user agent string.
func (input *BeegoInput) UserAgent() string {
return input.Header("User-Agent")
}
// ParamsLen return the length of the params
func (input *BeegoInput) ParamsLen() int {
return len(input.pnames)
}
// Param returns router param by a given key.
func (input *BeegoInput) Param(key string) string {
for i, v := range input.pnames {
if v == key && i <= len(input.pvalues) {
// we cannot use url.PathEscape(input.pvalues[i])
// for example, if the value is /a/b
// after url.PathEscape(input.pvalues[i]), the value is %2Fa%2Fb
// However, the value is used in ControllerRegister.ServeHTTP
// and split by "/", so function crash...
return input.pvalues[i]
}
}
return ""
}
// Params returns the map[key]value.
func (input *BeegoInput) Params() map[string]string {
m := make(map[string]string)
for i, v := range input.pnames {
if i <= len(input.pvalues) {
m[v] = input.pvalues[i]
}
}
return m
}
// SetParam sets the param with key and value
func (input *BeegoInput) SetParam(key, val string) {
// check if already exists
for i, v := range input.pnames {
if v == key && i <= len(input.pvalues) {
input.pvalues[i] = val
return
}
}
input.pvalues = append(input.pvalues, val)
input.pnames = append(input.pnames, key)
}
// ResetParams clears any of the input's params
// Used to clear parameters so they may be reset between filter passes.
func (input *BeegoInput) ResetParams() {
input.pnames = input.pnames[:0]
input.pvalues = input.pvalues[:0]
}
// Query returns input data item string by a given string.
func (input *BeegoInput) Query(key string) string {
if val := input.Param(key); val != "" {
return val
}
if input.Context.Request.Form == nil {
input.dataLock.Lock()
if input.Context.Request.Form == nil {
input.Context.Request.ParseForm()
}
input.dataLock.Unlock()
}
input.dataLock.RLock()
defer input.dataLock.RUnlock()
return input.Context.Request.Form.Get(key)
}
// Header returns request header item string by a given string.
// if non-existed, return empty string.
func (input *BeegoInput) Header(key string) string {
return input.Context.Request.Header.Get(key)
}
// Cookie returns request cookie item string by a given key.
// if non-existed, return empty string.
func (input *BeegoInput) Cookie(key string) string {
ck, err := input.Context.Request.Cookie(key)
if err != nil {
return ""
}
return ck.Value
}
// Session returns current session item value by a given key.
// if non-existed, return nil.
func (input *BeegoInput) Session(key interface{}) interface{} {
return input.CruSession.Get(nil, key)
}
// CopyBody returns the raw request body data as bytes.
func (input *BeegoInput) CopyBody(MaxMemory int64) []byte {
if input.Context.Request.Body == nil {
return []byte{}
}
var requestbody []byte
safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory}
if input.Header("Content-Encoding") == "gzip" {
reader, err := gzip.NewReader(safe)
if err != nil {
return nil
}
requestbody, _ = ioutil.ReadAll(reader)
} else {
requestbody, _ = ioutil.ReadAll(safe)
}
input.Context.Request.Body.Close()
bf := bytes.NewBuffer(requestbody)
input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, ioutil.NopCloser(bf), MaxMemory)
input.RequestBody = requestbody
return requestbody
}
// Data returns the implicit data in the input
func (input *BeegoInput) Data() map[interface{}]interface{} {
input.dataLock.Lock()
defer input.dataLock.Unlock()
if input.data == nil {
input.data = make(map[interface{}]interface{})
}
return input.data
}
// GetData returns the stored data in this context.
func (input *BeegoInput) GetData(key interface{}) interface{} {
input.dataLock.Lock()
defer input.dataLock.Unlock()
if v, ok := input.data[key]; ok {
return v
}
return nil
}
// SetData stores data with given key in this context.
// This data is only available in this context.
func (input *BeegoInput) SetData(key, val interface{}) {
input.dataLock.Lock()
defer input.dataLock.Unlock()
if input.data == nil {
input.data = make(map[interface{}]interface{})
}
input.data[key] = val
}
// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type
func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error {
// Parse the body depending on the content type.
if strings.Contains(input.Header("Content-Type"), "multipart/form-data") {
if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil {
return errors.New("Error parsing request body:" + err.Error())
}
} else if err := input.Context.Request.ParseForm(); err != nil {
return errors.New("Error parsing request body:" + err.Error())
}
return nil
}
// Bind data from request.Form[key] to dest
// like /?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie
// var id int beegoInput.Bind(&id, "id") id ==123
// var isok bool beegoInput.Bind(&isok, "isok") isok ==true
// var ft float64 beegoInput.Bind(&ft, "ft") ft ==1.2
// ol := make([]int, 0, 2) beegoInput.Bind(&ol, "ol") ol ==[1 2]
// ul := make([]string, 0, 2) beegoInput.Bind(&ul, "ul") ul ==[str array]
// user struct{Name} beegoInput.Bind(&user, "user") user == {Name:"astaxie"}
func (input *BeegoInput) Bind(dest interface{}, key string) error {
value := reflect.ValueOf(dest)
if value.Kind() != reflect.Ptr {
return errors.New("beego: non-pointer passed to Bind: " + key)
}
value = value.Elem()
if !value.CanSet() {
return errors.New("beego: non-settable variable passed to Bind: " + key)
}
typ := value.Type()
// Get real type if dest define with interface{}.
// e.g var dest interface{} dest=1.0
if value.Kind() == reflect.Interface {
typ = value.Elem().Type()
}
rv := input.bind(key, typ)
if !rv.IsValid() {
return errors.New("beego: reflect value is empty")
}
value.Set(rv)
return nil
}
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
if input.Context.Request.Form == nil {
input.Context.Request.ParseForm()
}
rv := reflect.Zero(typ)
switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
val := input.Query(key)
if len(val) == 0 {
return rv
}
rv = input.bindInt(val, typ)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
val := input.Query(key)
if len(val) == 0 {
return rv
}
rv = input.bindUint(val, typ)
case reflect.Float32, reflect.Float64:
val := input.Query(key)
if len(val) == 0 {
return rv
}
rv = input.bindFloat(val, typ)
case reflect.String:
val := input.Query(key)
if len(val) == 0 {
return rv
}
rv = input.bindString(val, typ)
case reflect.Bool:
val := input.Query(key)
if len(val) == 0 {
return rv
}
rv = input.bindBool(val, typ)
case reflect.Slice:
rv = input.bindSlice(&input.Context.Request.Form, key, typ)
case reflect.Struct:
rv = input.bindStruct(&input.Context.Request.Form, key, typ)
case reflect.Ptr:
rv = input.bindPoint(key, typ)
case reflect.Map:
rv = input.bindMap(&input.Context.Request.Form, key, typ)
}
return rv
}
func (input *BeegoInput) bindValue(val string, typ reflect.Type) reflect.Value {
rv := reflect.Zero(typ)
switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
rv = input.bindInt(val, typ)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
rv = input.bindUint(val, typ)
case reflect.Float32, reflect.Float64:
rv = input.bindFloat(val, typ)
case reflect.String:
rv = input.bindString(val, typ)
case reflect.Bool:
rv = input.bindBool(val, typ)
case reflect.Slice:
rv = input.bindSlice(&url.Values{"": {val}}, "", typ)
case reflect.Struct:
rv = input.bindStruct(&url.Values{"": {val}}, "", typ)
case reflect.Ptr:
rv = input.bindPoint(val, typ)
case reflect.Map:
rv = input.bindMap(&url.Values{"": {val}}, "", typ)
}
return rv
}
func (input *BeegoInput) bindInt(val string, typ reflect.Type) reflect.Value {
intValue, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return reflect.Zero(typ)
}
pValue := reflect.New(typ)
pValue.Elem().SetInt(intValue)
return pValue.Elem()
}
func (input *BeegoInput) bindUint(val string, typ reflect.Type) reflect.Value {
uintValue, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return reflect.Zero(typ)
}
pValue := reflect.New(typ)
pValue.Elem().SetUint(uintValue)
return pValue.Elem()
}
func (input *BeegoInput) bindFloat(val string, typ reflect.Type) reflect.Value {
floatValue, err := strconv.ParseFloat(val, 64)
if err != nil {
return reflect.Zero(typ)
}
pValue := reflect.New(typ)
pValue.Elem().SetFloat(floatValue)
return pValue.Elem()
}
func (input *BeegoInput) bindString(val string, typ reflect.Type) reflect.Value {
return reflect.ValueOf(val)
}
func (input *BeegoInput) bindBool(val string, typ reflect.Type) reflect.Value {
val = strings.TrimSpace(strings.ToLower(val))
switch val {
case "true", "on", "1":
return reflect.ValueOf(true)
}
return reflect.ValueOf(false)
}
type sliceValue struct {
index int // Index extracted from brackets. If -1, no index was provided.
value reflect.Value // the bound value for this slice element.
}
func (input *BeegoInput) bindSlice(params *url.Values, key string, typ reflect.Type) reflect.Value {
maxIndex := -1
numNoIndex := 0
sliceValues := []sliceValue{}
for reqKey, vals := range *params {
if !strings.HasPrefix(reqKey, key+"[") {
continue
}
// Extract the index, and the index where a sub-key starts. (e.g. field[0].subkey)
index := -1
leftBracket, rightBracket := len(key), strings.Index(reqKey[len(key):], "]")+len(key)
if rightBracket > leftBracket+1 {
index, _ = strconv.Atoi(reqKey[leftBracket+1 : rightBracket])
}
subKeyIndex := rightBracket + 1
// Handle the indexed case.
if index > -1 {
if index > maxIndex {
maxIndex = index
}
sliceValues = append(sliceValues, sliceValue{
index: index,
value: input.bind(reqKey[:subKeyIndex], typ.Elem()),
})
continue
}
// It's an un-indexed element. (e.g. element[])
numNoIndex += len(vals)
for _, val := range vals {
// Unindexed values can only be direct-bound.
sliceValues = append(sliceValues, sliceValue{
index: -1,
value: input.bindValue(val, typ.Elem()),
})
}
}
resultArray := reflect.MakeSlice(typ, maxIndex+1, maxIndex+1+numNoIndex)
for _, sv := range sliceValues {
if sv.index != -1 {
resultArray.Index(sv.index).Set(sv.value)
} else {
resultArray = reflect.Append(resultArray, sv.value)
}
}
return resultArray
}
func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.Type) reflect.Value {
result := reflect.New(typ).Elem()
fieldValues := make(map[string]reflect.Value)
for reqKey, val := range *params {
var fieldName string
if strings.HasPrefix(reqKey, key+".") {
fieldName = reqKey[len(key)+1:]
} else if strings.HasPrefix(reqKey, key+"[") && reqKey[len(reqKey)-1] == ']' {
fieldName = reqKey[len(key)+1 : len(reqKey)-1]
} else {
continue
}
if _, ok := fieldValues[fieldName]; !ok {
// Time to bind this field. Get it and make sure we can set it.
fieldValue := result.FieldByName(fieldName)
if !fieldValue.IsValid() {
continue
}
if !fieldValue.CanSet() {
continue
}
boundVal := input.bindValue(val[0], fieldValue.Type())
fieldValue.Set(boundVal)
fieldValues[fieldName] = boundVal
}
}
return result
}
func (input *BeegoInput) bindPoint(key string, typ reflect.Type) reflect.Value {
return input.bind(key, typ.Elem()).Addr()
}
func (input *BeegoInput) bindMap(params *url.Values, key string, typ reflect.Type) reflect.Value {
var (
result = reflect.MakeMap(typ)
keyType = typ.Key()
valueType = typ.Elem()
)
for paramName, values := range *params {
if !strings.HasPrefix(paramName, key+"[") || paramName[len(paramName)-1] != ']' {
continue
}
key := paramName[len(key)+1 : len(paramName)-1]
result.SetMapIndex(input.bindValue(key, keyType), input.bindValue(values[0], valueType))
}
return result
}

View File

@ -0,0 +1,217 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
func TestBind(t *testing.T) {
type testItem struct {
field string
empty interface{}
want interface{}
}
type Human struct {
ID int
Nick string
Pwd string
Ms bool
}
cases := []struct {
request string
valueGp []testItem
}{
{"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}},
{"/?p=", []testItem{{"p", "", ""}}},
{"/?p=str", []testItem{{"p", "", "str"}}},
{"/?p=123", []testItem{{"p", 0, 123}}},
{"/?p=123", []testItem{{"p", uint(0), uint(123)}}},
{"/?p=1.0", []testItem{{"p", 0.0, 1.0}}},
{"/?p=1", []testItem{{"p", false, true}}},
{"/?p=true", []testItem{{"p", false, true}}},
{"/?p=ON", []testItem{{"p", false, true}}},
{"/?p=on", []testItem{{"p", false, true}}},
{"/?p=1", []testItem{{"p", false, true}}},
{"/?p=2", []testItem{{"p", false, false}}},
{"/?p=false", []testItem{{"p", false, false}}},
{"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}},
{"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}},
{"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}},
{"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}},
{"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}},
{"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}},
{"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}},
{"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}},
{"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}},
{"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}},
{"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}},
{"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02",
[]testItem{{"human", []Human{}, []Human{
{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"},
{ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"},
}}}},
{
"/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie",
[]testItem{
{"id", 0, 123},
{"isok", false, true},
{"ft", 0.0, 1.2},
{"ol", []int{}, []int{1, 2}},
{"ul", []string{}, []string{"str", "array"}},
{"human", Human{}, Human{Nick: "astaxie"}},
},
},
}
for _, c := range cases {
r, _ := http.NewRequest("GET", c.request, nil)
beegoInput := NewInput()
beegoInput.Context = NewContext()
beegoInput.Context.Reset(httptest.NewRecorder(), r)
for _, item := range c.valueGp {
got := item.empty
err := beegoInput.Bind(&got, item.field)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, item.want) {
t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got)
}
}
}
}
func TestSubDomain(t *testing.T) {
r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
beegoInput := NewInput()
beegoInput.Context = NewContext()
beegoInput.Context.Reset(httptest.NewRecorder(), r)
subdomain := beegoInput.SubDomains()
if subdomain != "www" {
t.Fatal("Subdomain parse error, got" + subdomain)
}
r, _ = http.NewRequest("GET", "http://localhost/", nil)
beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, should be empty, got " + beegoInput.SubDomains())
}
r, _ = http.NewRequest("GET", "http://aa.bb.example.com/", nil)
beegoInput.Context.Request = r
if beegoInput.SubDomains() != "aa.bb" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
/* TODO Fix this
r, _ = http.NewRequest("GET", "http://127.0.0.1/", nil)
beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
*/
r, _ = http.NewRequest("GET", "http://example.com/", nil)
beegoInput.Context.Request = r
if beegoInput.SubDomains() != "" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
r, _ = http.NewRequest("GET", "http://aa.bb.cc.dd.example.com/", nil)
beegoInput.Context.Request = r
if beegoInput.SubDomains() != "aa.bb.cc.dd" {
t.Fatal("Subdomain parse error, got " + beegoInput.SubDomains())
}
}
func TestParams(t *testing.T) {
inp := NewInput()
inp.SetParam("p1", "val1_ver1")
inp.SetParam("p2", "val2_ver1")
inp.SetParam("p3", "val3_ver1")
if l := inp.ParamsLen(); l != 3 {
t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3)
}
if val := inp.Param("p1"); val != "val1_ver1" {
t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver1")
}
if val := inp.Param("p3"); val != "val3_ver1" {
t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val3_ver1")
}
vals := inp.Params()
expected := map[string]string{
"p1": "val1_ver1",
"p2": "val2_ver1",
"p3": "val3_ver1",
}
if !reflect.DeepEqual(vals, expected) {
t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected)
}
// overwriting existing params
inp.SetParam("p1", "val1_ver2")
inp.SetParam("p2", "val2_ver2")
expected = map[string]string{
"p1": "val1_ver2",
"p2": "val2_ver2",
"p3": "val3_ver1",
}
vals = inp.Params()
if !reflect.DeepEqual(vals, expected) {
t.Fatalf("Input.Params wrong value: %s, expected %s", vals, expected)
}
if l := inp.ParamsLen(); l != 3 {
t.Fatalf("Input.ParamsLen wrong value: %d, expected %d", l, 3)
}
if val := inp.Param("p1"); val != "val1_ver2" {
t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2")
}
if val := inp.Param("p2"); val != "val2_ver2" {
t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2")
}
}
func BenchmarkQuery(b *testing.B) {
beegoInput := NewInput()
beegoInput.Context = NewContext()
beegoInput.Context.Request, _ = http.NewRequest("POST", "http://www.example.com/?q=foo", nil)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
beegoInput.Query("q")
}
})
}

View File

@ -0,0 +1,408 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"bytes"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"html/template"
"io"
"mime"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
yaml "gopkg.in/yaml.v2"
)
// BeegoOutput does work for sending response header.
type BeegoOutput struct {
Context *Context
Status int
EnableGzip bool
}
// NewOutput returns new BeegoOutput.
// Empty when initialized
func NewOutput() *BeegoOutput {
return &BeegoOutput{}
}
// Reset initializes BeegoOutput
func (output *BeegoOutput) Reset(ctx *Context) {
output.Context = ctx
output.Status = 0
}
// Header sets response header item string via given key.
func (output *BeegoOutput) Header(key, val string) {
output.Context.ResponseWriter.Header().Set(key, val)
}
// Body sets the response body content.
// if EnableGzip, content is compressed.
// Sends out response body directly.
func (output *BeegoOutput) Body(content []byte) error {
var encoding string
var buf = &bytes.Buffer{}
if output.EnableGzip {
encoding = ParseEncoding(output.Context.Request)
}
if b, n, _ := WriteBody(encoding, buf, content); b {
output.Header("Content-Encoding", n)
output.Header("Content-Length", strconv.Itoa(buf.Len()))
} else {
output.Header("Content-Length", strconv.Itoa(len(content)))
}
// Write status code if it has been set manually
// Set it to 0 afterwards to prevent "multiple response.WriteHeader calls"
if output.Status != 0 {
output.Context.ResponseWriter.WriteHeader(output.Status)
output.Status = 0
} else {
output.Context.ResponseWriter.Started = true
}
io.Copy(output.Context.ResponseWriter, buf)
return nil
}
// Cookie sets a cookie value via given key.
// others: used to set a cookie's max age time, path,domain, secure and httponly.
func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) {
var b bytes.Buffer
fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value))
// fix cookie not work in IE
if len(others) > 0 {
var maxAge int64
switch v := others[0].(type) {
case int:
maxAge = int64(v)
case int32:
maxAge = int64(v)
case int64:
maxAge = v
}
switch {
case maxAge > 0:
fmt.Fprintf(&b, "; Expires=%s; Max-Age=%d", time.Now().Add(time.Duration(maxAge)*time.Second).UTC().Format(time.RFC1123), maxAge)
case maxAge < 0:
fmt.Fprintf(&b, "; Max-Age=0")
}
}
// the settings below
// Path, Domain, Secure, HttpOnly
// can use nil skip set
// default "/"
if len(others) > 1 {
if v, ok := others[1].(string); ok && len(v) > 0 {
fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v))
}
} else {
fmt.Fprintf(&b, "; Path=%s", "/")
}
// default empty
if len(others) > 2 {
if v, ok := others[2].(string); ok && len(v) > 0 {
fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(v))
}
}
// default empty
if len(others) > 3 {
var secure bool
switch v := others[3].(type) {
case bool:
secure = v
default:
if others[3] != nil {
secure = true
}
}
if secure {
fmt.Fprintf(&b, "; Secure")
}
}
// default false. for session cookie default true
if len(others) > 4 {
if v, ok := others[4].(bool); ok && v {
fmt.Fprintf(&b, "; HttpOnly")
}
}
output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String())
}
var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-")
func sanitizeName(n string) string {
return cookieNameSanitizer.Replace(n)
}
var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ")
func sanitizeValue(v string) string {
return cookieValueSanitizer.Replace(v)
}
func jsonRenderer(value interface{}) Renderer {
return rendererFunc(func(ctx *Context) {
ctx.Output.JSON(value, false, false)
})
}
func errorRenderer(err error) Renderer {
return rendererFunc(func(ctx *Context) {
ctx.Output.SetStatus(500)
ctx.Output.Body([]byte(err.Error()))
})
}
// JSON writes json to the response body.
// if encoding is true, it converts utf-8 to \u0000 type.
func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error {
output.Header("Content-Type", "application/json; charset=utf-8")
var content []byte
var err error
if hasIndent {
content, err = json.MarshalIndent(data, "", " ")
} else {
content, err = json.Marshal(data)
}
if err != nil {
http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError)
return err
}
if encoding {
content = []byte(stringsToJSON(string(content)))
}
return output.Body(content)
}
// YAML writes yaml to the response body.
func (output *BeegoOutput) YAML(data interface{}) error {
output.Header("Content-Type", "application/x-yaml; charset=utf-8")
var content []byte
var err error
content, err = yaml.Marshal(data)
if err != nil {
http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError)
return err
}
return output.Body(content)
}
// JSONP writes jsonp to the response body.
func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/javascript; charset=utf-8")
var content []byte
var err error
if hasIndent {
content, err = json.MarshalIndent(data, "", " ")
} else {
content, err = json.Marshal(data)
}
if err != nil {
http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError)
return err
}
callback := output.Context.Input.Query("callback")
if callback == "" {
return errors.New(`"callback" parameter required`)
}
callback = template.JSEscapeString(callback)
callbackContent := bytes.NewBufferString(" if(window." + callback + ")" + callback)
callbackContent.WriteString("(")
callbackContent.Write(content)
callbackContent.WriteString(");\r\n")
return output.Body(callbackContent.Bytes())
}
// XML writes xml string to the response body.
func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error {
output.Header("Content-Type", "application/xml; charset=utf-8")
var content []byte
var err error
if hasIndent {
content, err = xml.MarshalIndent(data, "", " ")
} else {
content, err = xml.Marshal(data)
}
if err != nil {
http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError)
return err
}
return output.Body(content)
}
// ServeFormatted serves YAML, XML or JSON, depending on the value of the Accept header
func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) {
accept := output.Context.Input.Header("Accept")
switch accept {
case ApplicationYAML:
output.YAML(data)
case ApplicationXML, TextXML:
output.XML(data, hasIndent)
default:
output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0])
}
}
// Download forces response for download file.
// Prepares the download response header automatically.
func (output *BeegoOutput) Download(file string, filename ...string) {
// check get file error, file not found or other error.
if _, err := os.Stat(file); err != nil {
http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file)
return
}
var fName string
if len(filename) > 0 && filename[0] != "" {
fName = filename[0]
} else {
fName = filepath.Base(file)
}
//https://tools.ietf.org/html/rfc6266#section-4.3
fn := url.PathEscape(fName)
if fName == fn {
fn = "filename=" + fn
} else {
/**
The parameters "filename" and "filename*" differ only in that
"filename*" uses the encoding defined in [RFC5987], allowing the use
of characters not present in the ISO-8859-1 character set
([ISO-8859-1]).
*/
fn = "filename=" + fName + "; filename*=utf-8''" + fn
}
output.Header("Content-Disposition", "attachment; "+fn)
output.Header("Content-Description", "File Transfer")
output.Header("Content-Type", "application/octet-stream")
output.Header("Content-Transfer-Encoding", "binary")
output.Header("Expires", "0")
output.Header("Cache-Control", "must-revalidate")
output.Header("Pragma", "public")
http.ServeFile(output.Context.ResponseWriter, output.Context.Request, file)
}
// ContentType sets the content type from ext string.
// MIME type is given in mime package.
func (output *BeegoOutput) ContentType(ext string) {
if !strings.HasPrefix(ext, ".") {
ext = "." + ext
}
ctype := mime.TypeByExtension(ext)
if ctype != "" {
output.Header("Content-Type", ctype)
}
}
// SetStatus sets the response status code.
// Writes response header directly.
func (output *BeegoOutput) SetStatus(status int) {
output.Status = status
}
// IsCachable returns boolean of if this request is cached.
// HTTP 304 means cached.
func (output *BeegoOutput) IsCachable() bool {
return output.Status >= 200 && output.Status < 300 || output.Status == 304
}
// IsEmpty returns boolean of if this request is empty.
// HTTP 201204 and 304 means empty.
func (output *BeegoOutput) IsEmpty() bool {
return output.Status == 201 || output.Status == 204 || output.Status == 304
}
// IsOk returns boolean of if this request was ok.
// HTTP 200 means ok.
func (output *BeegoOutput) IsOk() bool {
return output.Status == 200
}
// IsSuccessful returns boolean of this request was successful.
// HTTP 2xx means ok.
func (output *BeegoOutput) IsSuccessful() bool {
return output.Status >= 200 && output.Status < 300
}
// IsRedirect returns boolean of if this request is redirected.
// HTTP 301,302,307 means redirection.
func (output *BeegoOutput) IsRedirect() bool {
return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307
}
// IsForbidden returns boolean of if this request is forbidden.
// HTTP 403 means forbidden.
func (output *BeegoOutput) IsForbidden() bool {
return output.Status == 403
}
// IsNotFound returns boolean of if this request is not found.
// HTTP 404 means not found.
func (output *BeegoOutput) IsNotFound() bool {
return output.Status == 404
}
// IsClientError returns boolean of if this request client sends error data.
// HTTP 4xx means client error.
func (output *BeegoOutput) IsClientError() bool {
return output.Status >= 400 && output.Status < 500
}
// IsServerError returns boolean of if this server handler errors.
// HTTP 5xx means server internal error.
func (output *BeegoOutput) IsServerError() bool {
return output.Status >= 500 && output.Status < 600
}
func stringsToJSON(str string) string {
var jsons bytes.Buffer
for _, r := range str {
rint := int(r)
if rint < 128 {
jsons.WriteRune(r)
} else {
jsons.WriteString("\\u")
if rint < 0x100 {
jsons.WriteString("00")
} else if rint < 0x1000 {
jsons.WriteString("0")
}
jsons.WriteString(strconv.FormatInt(int64(rint), 16))
}
}
return jsons.String()
}
// Session sets session item value with given key.
func (output *BeegoOutput) Session(name interface{}, value interface{}) {
output.Context.Input.CruSession.Set(nil, name, value)
}

View File

@ -0,0 +1,78 @@
package param
import (
"fmt"
"reflect"
"github.com/astaxie/beego/core/logs"
beecontext "github.com/astaxie/beego/server/web/context"
)
// ConvertParams converts http method params to values that will be passed to the method controller as arguments
func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) {
result = make([]reflect.Value, 0, len(methodParams))
for i := 0; i < len(methodParams); i++ {
reflectValue := convertParam(methodParams[i], methodType.In(i), ctx)
result = append(result, reflectValue)
}
return
}
func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) {
paramValue := getParamValue(param, ctx)
if paramValue == "" {
if param.required {
ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name))
} else {
paramValue = param.defaultValue
}
}
reflectValue, err := parseValue(param, paramValue, paramType)
if err != nil {
logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err))
ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType))
}
return reflectValue
}
func getParamValue(param *MethodParam, ctx *beecontext.Context) string {
switch param.in {
case body:
return string(ctx.Input.RequestBody)
case header:
return ctx.Input.Header(param.name)
case path:
return ctx.Input.Query(":" + param.name)
default:
return ctx.Input.Query(param.name)
}
}
func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) {
if paramValue == "" {
return reflect.Zero(paramType), nil
}
parser := getParser(param, paramType)
value, err := parser.parse(paramValue, paramType)
if err != nil {
return result, err
}
return safeConvert(reflect.ValueOf(value), paramType)
}
func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok = r.(error)
if !ok {
err = fmt.Errorf("%v", r)
}
}
}()
result = value.Convert(t)
return
}

View File

@ -0,0 +1,69 @@
package param
import (
"fmt"
"strings"
)
//MethodParam keeps param information to be auto passed to controller methods
type MethodParam struct {
name string
in paramType
required bool
defaultValue string
}
type paramType byte
const (
param paramType = iota
path
body
header
)
// New creates a new MethodParam with name and specific options
func New(name string, opts ...MethodParamOption) *MethodParam {
return newParam(name, nil, opts)
}
func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) {
param = &MethodParam{name: name}
for _, option := range opts {
option(param)
}
return
}
// Make creates an array of MethodParmas or an empty array
func Make(list ...*MethodParam) []*MethodParam {
if len(list) > 0 {
return list
}
return nil
}
func (mp *MethodParam) String() string {
options := []string{}
result := "param.New(\"" + mp.name + "\""
if mp.required {
options = append(options, "param.IsRequired")
}
switch mp.in {
case path:
options = append(options, "param.InPath")
case body:
options = append(options, "param.InBody")
case header:
options = append(options, "param.InHeader")
}
if mp.defaultValue != "" {
options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue))
}
if len(options) > 0 {
result += ", "
}
result += strings.Join(options, ", ")
result += ")"
return result
}

View File

@ -0,0 +1,37 @@
package param
import (
"fmt"
)
// MethodParamOption defines a func which apply options on a MethodParam
type MethodParamOption func(*MethodParam)
// IsRequired indicates that this param is required and can not be omitted from the http request
var IsRequired MethodParamOption = func(p *MethodParam) {
p.required = true
}
// InHeader indicates that this param is passed via an http header
var InHeader MethodParamOption = func(p *MethodParam) {
p.in = header
}
// InPath indicates that this param is part of the URL path
var InPath MethodParamOption = func(p *MethodParam) {
p.in = path
}
// InBody indicates that this param is passed as an http request body
var InBody MethodParamOption = func(p *MethodParam) {
p.in = body
}
// Default provides a default value for the http param
func Default(defaultValue interface{}) MethodParamOption {
return func(p *MethodParam) {
if defaultValue != nil {
p.defaultValue = fmt.Sprint(defaultValue)
}
}
}

View File

@ -0,0 +1,149 @@
package param
import (
"encoding/json"
"reflect"
"strconv"
"strings"
"time"
)
type paramParser interface {
parse(value string, toType reflect.Type) (interface{}, error)
}
func getParser(param *MethodParam, t reflect.Type) paramParser {
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return intParser{}
case reflect.Slice:
if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string
return stringParser{}
}
if param.in == body {
return jsonParser{}
}
elemParser := getParser(param, t.Elem())
if elemParser == (jsonParser{}) {
return elemParser
}
return sliceParser(elemParser)
case reflect.Bool:
return boolParser{}
case reflect.String:
return stringParser{}
case reflect.Float32, reflect.Float64:
return floatParser{}
case reflect.Ptr:
elemParser := getParser(param, t.Elem())
if elemParser == (jsonParser{}) {
return elemParser
}
return ptrParser(elemParser)
default:
if t.PkgPath() == "time" && t.Name() == "Time" {
return timeParser{}
}
return jsonParser{}
}
}
type parserFunc func(value string, toType reflect.Type) (interface{}, error)
func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) {
return f(value, toType)
}
type boolParser struct {
}
func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) {
return strconv.ParseBool(value)
}
type stringParser struct {
}
func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) {
return value, nil
}
type intParser struct {
}
func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) {
return strconv.Atoi(value)
}
type floatParser struct {
}
func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) {
if toType.Kind() == reflect.Float32 {
res, err := strconv.ParseFloat(value, 32)
if err != nil {
return nil, err
}
return float32(res), nil
}
return strconv.ParseFloat(value, 64)
}
type timeParser struct {
}
func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) {
result, err = time.Parse(time.RFC3339, value)
if err != nil {
result, err = time.Parse("2006-01-02", value)
}
return
}
type jsonParser struct {
}
func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) {
pResult := reflect.New(toType)
v := pResult.Interface()
err := json.Unmarshal([]byte(value), v)
if err != nil {
return nil, err
}
return pResult.Elem().Interface(), nil
}
func sliceParser(elemParser paramParser) paramParser {
return parserFunc(func(value string, toType reflect.Type) (interface{}, error) {
values := strings.Split(value, ",")
result := reflect.MakeSlice(toType, 0, len(values))
elemType := toType.Elem()
for _, v := range values {
parsedValue, err := elemParser.parse(v, elemType)
if err != nil {
return nil, err
}
result = reflect.Append(result, reflect.ValueOf(parsedValue))
}
return result.Interface(), nil
})
}
func ptrParser(elemParser paramParser) paramParser {
return parserFunc(func(value string, toType reflect.Type) (interface{}, error) {
parsedValue, err := elemParser.parse(value, toType.Elem())
if err != nil {
return nil, err
}
newValPtr := reflect.New(toType.Elem())
newVal := reflect.Indirect(newValPtr)
convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem())
if err != nil {
return nil, err
}
newVal.Set(convertedVal)
return newValPtr.Interface(), nil
})
}

View File

@ -0,0 +1,86 @@
package param
import (
"reflect"
"testing"
"time"
)
type testDefinition struct {
strValue string
expectedValue interface{}
expectedParser paramParser
}
func Test_Parsers(t *testing.T) {
//ints
checkParser(testDefinition{"1", 1, intParser{}}, t)
checkParser(testDefinition{"-1", int64(-1), intParser{}}, t)
checkParser(testDefinition{"1", uint64(1), intParser{}}, t)
//floats
checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t)
checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t)
//strings
checkParser(testDefinition{"AB", "AB", stringParser{}}, t)
checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t)
//bools
checkParser(testDefinition{"true", true, boolParser{}}, t)
checkParser(testDefinition{"0", false, boolParser{}}, t)
//timeParser
checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t)
checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t)
//json
checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct {
X int
Y string
}{5, "Z"}, jsonParser{}}, t)
//slice in query is parsed as comma delimited
checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t)
//slice in body is parsed as json
checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body})
//pointers
var someInt = 1
checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t)
var someStruct = struct{ X int }{5}
checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t)
}
func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) {
toType := reflect.TypeOf(def.expectedValue)
var mp MethodParam
if len(methodParam) == 0 {
mp = MethodParam{}
} else {
mp = methodParam[0]
}
parser := getParser(&mp, toType)
if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) {
t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name())
return
}
result, err := parser.parse(def.strValue, toType)
if err != nil {
t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err)
return
}
convResult, err := safeConvert(reflect.ValueOf(result), toType)
if err != nil {
t.Errorf("Conversion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err)
return
}
if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) {
t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result)
}
}

View File

@ -0,0 +1,12 @@
package context
// Renderer defines a http response renderer
type Renderer interface {
Render(ctx *Context)
}
type rendererFunc func(ctx *Context)
func (f rendererFunc) Render(ctx *Context) {
f(ctx)
}

View File

@ -0,0 +1,26 @@
package context
import (
"net/http"
"strconv"
)
const (
//BadRequest indicates HTTP error 400
BadRequest StatusCode = http.StatusBadRequest
//NotFound indicates HTTP error 404
NotFound StatusCode = http.StatusNotFound
)
// StatusCode sets the HTTP response status code
type StatusCode int
func (s StatusCode) Error() string {
return strconv.Itoa(int(s))
}
// Render sets the HTTP status code
func (s StatusCode) Render(ctx *Context) {
ctx.Output.SetStatus(int(s))
}

707
server/web/controller.go Normal file
View File

@ -0,0 +1,707 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"bytes"
"errors"
"fmt"
"html/template"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"github.com/astaxie/beego/server/web/session"
"github.com/astaxie/beego/server/web/context"
"github.com/astaxie/beego/server/web/context/param"
)
var (
// ErrAbort custom error when user stop request handler manually.
ErrAbort = errors.New("user stop run")
// GlobalControllerRouter store comments with controller. pkgpath+controller:comments
GlobalControllerRouter = make(map[string][]ControllerComments)
)
// ControllerFilter store the filter for controller
type ControllerFilter struct {
Pattern string
Pos int
Filter FilterFunc
ReturnOnOutput bool
ResetParams bool
}
// ControllerFilterComments store the comment for controller level filter
type ControllerFilterComments struct {
Pattern string
Pos int
Filter string // NOQA
ReturnOnOutput bool
ResetParams bool
}
// ControllerImportComments store the import comment for controller needed
type ControllerImportComments struct {
ImportPath string
ImportAlias string
}
// ControllerComments store the comment for the controller method
type ControllerComments struct {
Method string
Router string
Filters []*ControllerFilter
ImportComments []*ControllerImportComments
FilterComments []*ControllerFilterComments
AllowHTTPMethods []string
Params []map[string]string
MethodParams []*param.MethodParam
}
// ControllerCommentsSlice implements the sort interface
type ControllerCommentsSlice []ControllerComments
func (p ControllerCommentsSlice) Len() int { return len(p) }
func (p ControllerCommentsSlice) Less(i, j int) bool { return p[i].Router < p[j].Router }
func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// Controller defines some basic http request handler operations, such as
// http context, template and view, session and xsrf.
type Controller struct {
// context data
Ctx *context.Context
Data map[interface{}]interface{}
// route controller info
controllerName string
actionName string
methodMapping map[string]func() //method:routertree
AppController interface{}
// template data
TplName string
ViewPath string
Layout string
LayoutSections map[string]string // the key is the section name and the value is the template name
TplPrefix string
TplExt string
EnableRender bool
// xsrf data
_xsrfToken string
XSRFExpire int
EnableXSRF bool
// session
CruSession session.Store
}
// ControllerInterface is an interface to uniform all controller handler.
type ControllerInterface interface {
Init(ct *context.Context, controllerName, actionName string, app interface{})
Prepare()
Get()
Post()
Delete()
Put()
Head()
Patch()
Options()
Trace()
Finish()
Render() error
XSRFToken() string
CheckXSRFCookie() bool
HandlerFunc(fn string) bool
URLMapping()
}
// Init generates default values of controller operations.
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
c.Layout = ""
c.TplName = ""
c.controllerName = controllerName
c.actionName = actionName
c.Ctx = ctx
c.TplExt = "tpl"
c.AppController = app
c.EnableRender = true
c.EnableXSRF = true
c.Data = ctx.Input.Data()
c.methodMapping = make(map[string]func())
}
// Prepare runs after Init before request function execution.
func (c *Controller) Prepare() {}
// Finish runs after request function execution.
func (c *Controller) Finish() {}
// Get adds a request function to handle GET request.
func (c *Controller) Get() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Post adds a request function to handle POST request.
func (c *Controller) Post() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Delete adds a request function to handle DELETE request.
func (c *Controller) Delete() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Put adds a request function to handle PUT request.
func (c *Controller) Put() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Head adds a request function to handle HEAD request.
func (c *Controller) Head() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Patch adds a request function to handle PATCH request.
func (c *Controller) Patch() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Options adds a request function to handle OPTIONS request.
func (c *Controller) Options() {
http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
}
// Trace adds a request function to handle Trace request.
// this method SHOULD NOT be overridden.
// https://tools.ietf.org/html/rfc7231#section-4.3.8
// The TRACE method requests a remote, application-level loop-back of
// the request message. The final recipient of the request SHOULD
// reflect the message received, excluding some fields described below,
// back to the client as the message body of a 200 (OK) response with a
// Content-Type of "message/http" (Section 8.3.1 of [RFC7230]).
func (c *Controller) Trace() {
ts := func(h http.Header) (hs string) {
for k, v := range h {
hs += fmt.Sprintf("\r\n%s: %s", k, v)
}
return
}
hs := fmt.Sprintf("\r\nTRACE %s %s%s\r\n", c.Ctx.Request.RequestURI, c.Ctx.Request.Proto, ts(c.Ctx.Request.Header))
c.Ctx.Output.Header("Content-Type", "message/http")
c.Ctx.Output.Header("Content-Length", fmt.Sprint(len(hs)))
c.Ctx.Output.Header("Cache-Control", "no-cache, no-store, must-revalidate")
c.Ctx.WriteString(hs)
}
// HandlerFunc call function with the name
func (c *Controller) HandlerFunc(fnname string) bool {
if v, ok := c.methodMapping[fnname]; ok {
v()
return true
}
return false
}
// URLMapping register the internal Controller router.
func (c *Controller) URLMapping() {}
// Mapping the method to function
func (c *Controller) Mapping(method string, fn func()) {
c.methodMapping[method] = fn
}
// Render sends the response with rendered template bytes as text/html type.
func (c *Controller) Render() error {
if !c.EnableRender {
return nil
}
rb, err := c.RenderBytes()
if err != nil {
return err
}
if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" {
c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
}
return c.Ctx.Output.Body(rb)
}
// RenderString returns the rendered template string. Do not send out response.
func (c *Controller) RenderString() (string, error) {
b, e := c.RenderBytes()
return string(b), e
}
// RenderBytes returns the bytes of rendered template string. Do not send out response.
func (c *Controller) RenderBytes() ([]byte, error) {
buf, err := c.renderTemplate()
//if the controller has set layout, then first get the tplName's content set the content to the layout
if err == nil && c.Layout != "" {
c.Data["LayoutContent"] = template.HTML(buf.String())
if c.LayoutSections != nil {
for sectionName, sectionTpl := range c.LayoutSections {
if sectionTpl == "" {
c.Data[sectionName] = ""
continue
}
buf.Reset()
err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data)
if err != nil {
return nil, err
}
c.Data[sectionName] = template.HTML(buf.String())
}
}
buf.Reset()
ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data)
}
return buf.Bytes(), err
}
func (c *Controller) renderTemplate() (bytes.Buffer, error) {
var buf bytes.Buffer
if c.TplName == "" {
c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
}
if c.TplPrefix != "" {
c.TplName = c.TplPrefix + c.TplName
}
if BConfig.RunMode == DEV {
buildFiles := []string{c.TplName}
if c.Layout != "" {
buildFiles = append(buildFiles, c.Layout)
if c.LayoutSections != nil {
for _, sectionTpl := range c.LayoutSections {
if sectionTpl == "" {
continue
}
buildFiles = append(buildFiles, sectionTpl)
}
}
}
BuildTemplate(c.viewPath(), buildFiles...)
}
return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data)
}
func (c *Controller) viewPath() string {
if c.ViewPath == "" {
return BConfig.WebConfig.ViewsPath
}
return c.ViewPath
}
// Redirect sends the redirection response to url with status code.
func (c *Controller) Redirect(url string, code int) {
LogAccess(c.Ctx, nil, code)
c.Ctx.Redirect(code, url)
}
// SetData set the data depending on the accepted
func (c *Controller) SetData(data interface{}) {
accept := c.Ctx.Input.Header("Accept")
switch accept {
case context.ApplicationYAML:
c.Data["yaml"] = data
case context.ApplicationXML, context.TextXML:
c.Data["xml"] = data
default:
c.Data["json"] = data
}
}
// Abort stops controller handler and show the error data if code is defined in ErrorMap or code string.
func (c *Controller) Abort(code string) {
status, err := strconv.Atoi(code)
if err != nil {
status = 200
}
c.CustomAbort(status, code)
}
// CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body.
func (c *Controller) CustomAbort(status int, body string) {
// first panic from ErrorMaps, it is user defined error functions.
if _, ok := ErrorMaps[body]; ok {
c.Ctx.Output.Status = status
panic(body)
}
// last panic user string
c.Ctx.ResponseWriter.WriteHeader(status)
c.Ctx.ResponseWriter.Write([]byte(body))
panic(ErrAbort)
}
// StopRun makes panic of USERSTOPRUN error and go to recover function if defined.
func (c *Controller) StopRun() {
panic(ErrAbort)
}
// URLFor does another controller handler in this request function.
// it goes to this controller method if endpoint is not clear.
func (c *Controller) URLFor(endpoint string, values ...interface{}) string {
if len(endpoint) == 0 {
return ""
}
if endpoint[0] == '.' {
return URLFor(reflect.Indirect(reflect.ValueOf(c.AppController)).Type().Name()+endpoint, values...)
}
return URLFor(endpoint, values...)
}
// ServeJSON sends a json response with encoding charset.
func (c *Controller) ServeJSON(encoding ...bool) {
var (
hasIndent = BConfig.RunMode != PROD
hasEncoding = len(encoding) > 0 && encoding[0]
)
c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding)
}
// ServeJSONP sends a jsonp response.
func (c *Controller) ServeJSONP() {
hasIndent := BConfig.RunMode != PROD
c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent)
}
// ServeXML sends xml response.
func (c *Controller) ServeXML() {
hasIndent := BConfig.RunMode != PROD
c.Ctx.Output.XML(c.Data["xml"], hasIndent)
}
// ServeYAML sends yaml response.
func (c *Controller) ServeYAML() {
c.Ctx.Output.YAML(c.Data["yaml"])
}
// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header
func (c *Controller) ServeFormatted(encoding ...bool) {
hasIndent := BConfig.RunMode != PROD
hasEncoding := len(encoding) > 0 && encoding[0]
c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding)
}
// Input returns the input data map from POST or PUT request body and query string.
func (c *Controller) Input() url.Values {
if c.Ctx.Request.Form == nil {
c.Ctx.Request.ParseForm()
}
return c.Ctx.Request.Form
}
// ParseForm maps input data map to obj struct.
func (c *Controller) ParseForm(obj interface{}) error {
return ParseForm(c.Input(), obj)
}
// GetString returns the input value by key string or the default value while it's present and input is blank
func (c *Controller) GetString(key string, def ...string) string {
if v := c.Ctx.Input.Query(key); v != "" {
return v
}
if len(def) > 0 {
return def[0]
}
return ""
}
// GetStrings returns the input string slice by key string or the default value while it's present and input is blank
// it's designed for multi-value input field such as checkbox(input[type=checkbox]), multi-selection.
func (c *Controller) GetStrings(key string, def ...[]string) []string {
var defv []string
if len(def) > 0 {
defv = def[0]
}
if f := c.Input(); f == nil {
return defv
} else if vs := f[key]; len(vs) > 0 {
return vs
}
return defv
}
// GetInt returns input as an int or the default value while it's present and input is blank
func (c *Controller) GetInt(key string, def ...int) (int, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
return strconv.Atoi(strv)
}
// GetInt8 return input as an int8 or the default value while it's present and input is blank
func (c *Controller) GetInt8(key string, def ...int8) (int8, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
i64, err := strconv.ParseInt(strv, 10, 8)
return int8(i64), err
}
// GetUint8 return input as an uint8 or the default value while it's present and input is blank
func (c *Controller) GetUint8(key string, def ...uint8) (uint8, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
u64, err := strconv.ParseUint(strv, 10, 8)
return uint8(u64), err
}
// GetInt16 returns input as an int16 or the default value while it's present and input is blank
func (c *Controller) GetInt16(key string, def ...int16) (int16, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
i64, err := strconv.ParseInt(strv, 10, 16)
return int16(i64), err
}
// GetUint16 returns input as an uint16 or the default value while it's present and input is blank
func (c *Controller) GetUint16(key string, def ...uint16) (uint16, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
u64, err := strconv.ParseUint(strv, 10, 16)
return uint16(u64), err
}
// GetInt32 returns input as an int32 or the default value while it's present and input is blank
func (c *Controller) GetInt32(key string, def ...int32) (int32, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
i64, err := strconv.ParseInt(strv, 10, 32)
return int32(i64), err
}
// GetUint32 returns input as an uint32 or the default value while it's present and input is blank
func (c *Controller) GetUint32(key string, def ...uint32) (uint32, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
u64, err := strconv.ParseUint(strv, 10, 32)
return uint32(u64), err
}
// GetInt64 returns input value as int64 or the default value while it's present and input is blank.
func (c *Controller) GetInt64(key string, def ...int64) (int64, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
return strconv.ParseInt(strv, 10, 64)
}
// GetUint64 returns input value as uint64 or the default value while it's present and input is blank.
func (c *Controller) GetUint64(key string, def ...uint64) (uint64, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
return strconv.ParseUint(strv, 10, 64)
}
// GetBool returns input value as bool or the default value while it's present and input is blank.
func (c *Controller) GetBool(key string, def ...bool) (bool, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
return strconv.ParseBool(strv)
}
// GetFloat returns input value as float64 or the default value while it's present and input is blank.
func (c *Controller) GetFloat(key string, def ...float64) (float64, error) {
strv := c.Ctx.Input.Query(key)
if len(strv) == 0 && len(def) > 0 {
return def[0], nil
}
return strconv.ParseFloat(strv, 64)
}
// GetFile returns the file data in file upload field named as key.
// it returns the first one of multi-uploaded files.
func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, error) {
return c.Ctx.Request.FormFile(key)
}
// GetFiles return multi-upload files
// files, err:=c.GetFiles("myfiles")
// if err != nil {
// http.Error(w, err.Error(), http.StatusNoContent)
// return
// }
// for i, _ := range files {
// //for each fileheader, get a handle to the actual file
// file, err := files[i].Open()
// defer file.Close()
// if err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
// //create destination file making sure the path is writeable.
// dst, err := os.Create("upload/" + files[i].Filename)
// defer dst.Close()
// if err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
// //copy the uploaded file to the destination file
// if _, err := io.Copy(dst, file); err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
// }
func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) {
if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok {
return files, nil
}
return nil, http.ErrMissingFile
}
// SaveToFile saves uploaded file to new path.
// it only operates the first one of mutil-upload form file field.
func (c *Controller) SaveToFile(fromfile, tofile string) error {
file, _, err := c.Ctx.Request.FormFile(fromfile)
if err != nil {
return err
}
defer file.Close()
f, err := os.OpenFile(tofile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
if err != nil {
return err
}
defer f.Close()
io.Copy(f, file)
return nil
}
// StartSession starts session and load old session data info this controller.
func (c *Controller) StartSession() session.Store {
if c.CruSession == nil {
c.CruSession = c.Ctx.Input.CruSession
}
return c.CruSession
}
// SetSession puts value into session.
func (c *Controller) SetSession(name interface{}, value interface{}) {
if c.CruSession == nil {
c.StartSession()
}
c.CruSession.Set(nil, name, value)
}
// GetSession gets value from session.
func (c *Controller) GetSession(name interface{}) interface{} {
if c.CruSession == nil {
c.StartSession()
}
return c.CruSession.Get(nil, name)
}
// DelSession removes value from session.
func (c *Controller) DelSession(name interface{}) {
if c.CruSession == nil {
c.StartSession()
}
c.CruSession.Delete(nil, name)
}
// SessionRegenerateID regenerates session id for this session.
// the session data have no changes.
func (c *Controller) SessionRegenerateID() {
if c.CruSession != nil {
c.CruSession.SessionRelease(nil, c.Ctx.ResponseWriter)
}
c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession
}
// DestroySession cleans session data and session cookie.
func (c *Controller) DestroySession() {
c.Ctx.Input.CruSession.Flush(nil)
c.Ctx.Input.CruSession = nil
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
}
// IsAjax returns this request is ajax or not.
func (c *Controller) IsAjax() bool {
return c.Ctx.Input.IsAjax()
}
// GetSecureCookie returns decoded cookie value from encoded browser cookie values.
func (c *Controller) GetSecureCookie(Secret, key string) (string, bool) {
return c.Ctx.GetSecureCookie(Secret, key)
}
// SetSecureCookie puts value into cookie after encoded the value.
func (c *Controller) SetSecureCookie(Secret, name, value string, others ...interface{}) {
c.Ctx.SetSecureCookie(Secret, name, value, others...)
}
// XSRFToken creates a CSRF token string and returns.
func (c *Controller) XSRFToken() string {
if c._xsrfToken == "" {
expire := int64(BConfig.WebConfig.XSRFExpire)
if c.XSRFExpire > 0 {
expire = int64(c.XSRFExpire)
}
c._xsrfToken = c.Ctx.XSRFToken(BConfig.WebConfig.XSRFKey, expire)
}
return c._xsrfToken
}
// CheckXSRFCookie checks xsrf token in this request is valid or not.
// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken"
// or in form field value named as "_xsrf".
func (c *Controller) CheckXSRFCookie() bool {
if !c.EnableXSRF {
return true
}
return c.Ctx.CheckXSRFCookie()
}
// XSRFFormHTML writes an input field contains xsrf token value.
func (c *Controller) XSRFFormHTML() string {
return `<input type="hidden" name="_xsrf" value="` +
c.XSRFToken() + `" />`
}
// GetControllerAndAction gets the executing controller name and action name.
func (c *Controller) GetControllerAndAction() (string, string) {
return c.controllerName, c.actionName
}

View File

@ -0,0 +1,185 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"math"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/server/web/context"
)
func TestGetInt(t *testing.T) {
i := context.NewInput()
i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt("age")
if val != 40 {
t.Errorf("TestGetInt expect 40,get %T,%v", val, val)
}
}
func TestGetInt8(t *testing.T) {
i := context.NewInput()
i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt8("age")
if val != 40 {
t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val)
}
//Output: int8
}
func TestGetInt16(t *testing.T) {
i := context.NewInput()
i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt16("age")
if val != 40 {
t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val)
}
}
func TestGetInt32(t *testing.T) {
i := context.NewInput()
i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt32("age")
if val != 40 {
t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val)
}
}
func TestGetInt64(t *testing.T) {
i := context.NewInput()
i.SetParam("age", "40")
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetInt64("age")
if val != 40 {
t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val)
}
}
func TestGetUint8(t *testing.T) {
i := context.NewInput()
i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10))
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetUint8("age")
if val != math.MaxUint8 {
t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val)
}
}
func TestGetUint16(t *testing.T) {
i := context.NewInput()
i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10))
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetUint16("age")
if val != math.MaxUint16 {
t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val)
}
}
func TestGetUint32(t *testing.T) {
i := context.NewInput()
i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10))
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetUint32("age")
if val != math.MaxUint32 {
t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val)
}
}
func TestGetUint64(t *testing.T) {
i := context.NewInput()
i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10))
ctx := &context.Context{Input: i}
ctrlr := Controller{Ctx: ctx}
val, _ := ctrlr.GetUint64("age")
if val != math.MaxUint64 {
t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val)
}
}
func TestAdditionalViewPaths(t *testing.T) {
wkdir, err := os.Getwd()
assert.Nil(t, err)
dir1 := filepath.Join(wkdir, "_beeTmp", "TestAdditionalViewPaths")
dir2 := filepath.Join(wkdir, "_beeTmp2", "TestAdditionalViewPaths")
defer os.RemoveAll(dir1)
defer os.RemoveAll(dir2)
dir1file := "file1.tpl"
dir2file := "file2.tpl"
genFile := func(dir string, name string, content string) {
os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
t.Fatal(err)
} else {
defer f.Close()
f.WriteString(content)
f.Close()
}
}
genFile(dir1, dir1file, `<div>{{.Content}}</div>`)
genFile(dir2, dir2file, `<html>{{.Content}}</html>`)
AddViewPath(dir1)
AddViewPath(dir2)
ctrl := Controller{
TplName: "file1.tpl",
ViewPath: dir1,
}
ctrl.Data = map[interface{}]interface{}{
"Content": "value2",
}
if result, err := ctrl.RenderString(); err != nil {
t.Fatal(err)
} else {
if result != "<div>value2</div>" {
t.Fatalf("TestAdditionalViewPaths expect %s got %s", "<div>value2</div>", result)
}
}
func() {
ctrl.TplName = "file2.tpl"
defer func() {
if r := recover(); r == nil {
t.Fatal("TestAdditionalViewPaths expected error")
}
}()
ctrl.RenderString()
}()
ctrl.TplName = "file2.tpl"
ctrl.ViewPath = dir2
ctrl.RenderString()
}

17
server/web/doc.go Normal file
View File

@ -0,0 +1,17 @@
/*
Package beego provide a MVC framework
beego: an open-source, high-performance, modular, full-stack web framework
It is used for rapid development of RESTful APIs, web apps and backend services in Go.
beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding.
package main
import "github.com/astaxie/beego"
func main() {
beego.Run()
}
more information: http://beego.me
*/
package web

490
server/web/error.go Normal file
View File

@ -0,0 +1,490 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"fmt"
"html/template"
"net/http"
"reflect"
"runtime"
"strconv"
"strings"
"github.com/astaxie/beego"
"github.com/astaxie/beego/core/utils"
"github.com/astaxie/beego/server/web/context"
)
const (
errorTypeHandler = iota
errorTypeController
)
var tpl = `
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
<title>beego application error</title>
<style>
html, body, body * {padding: 0; margin: 0;}
#header {background:#ffd; border-bottom:solid 2px #A31515; padding: 20px 10px;}
#header h2{ }
#footer {border-top:solid 1px #aaa; padding: 5px 10px; font-size: 12px; color:green;}
#content {padding: 5px;}
#content .stack b{ font-size: 13px; color: red;}
#content .stack pre{padding-left: 10px;}
table {}
td.t {text-align: right; padding-right: 5px; color: #888;}
</style>
<script type="text/javascript">
</script>
</head>
<body>
<div id="header">
<h2>{{.AppError}}</h2>
</div>
<div id="content">
<table>
<tr>
<td class="t">Request Method: </td><td>{{.RequestMethod}}</td>
</tr>
<tr>
<td class="t">Request URL: </td><td>{{.RequestURL}}</td>
</tr>
<tr>
<td class="t">RemoteAddr: </td><td>{{.RemoteAddr }}</td>
</tr>
</table>
<div class="stack">
<b>Stack</b>
<pre>{{.Stack}}</pre>
</div>
</div>
<div id="footer">
<p>beego {{ .BeegoVersion }} (beego framework)</p>
<p>golang version: {{.GoVersion}}</p>
</div>
</body>
</html>
`
// render default application error page with error and stack string.
func showErr(err interface{}, ctx *context.Context, stack string) {
t, _ := template.New("beegoerrortemp").Parse(tpl)
data := map[string]string{
"AppError": fmt.Sprintf("%s:%v", BConfig.AppName, err),
"RequestMethod": ctx.Input.Method(),
"RequestURL": ctx.Input.URI(),
"RemoteAddr": ctx.Input.IP(),
"Stack": stack,
"BeegoVersion": beego.VERSION,
"GoVersion": runtime.Version(),
}
t.Execute(ctx.ResponseWriter, data)
}
var errtpl = `
<!DOCTYPE html>
<html lang="en">
<head>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<title>{{.Title}}</title>
<style type="text/css">
* {
margin:0;
padding:0;
}
body {
background-color:#EFEFEF;
font: .9em "Lucida Sans Unicode", "Lucida Grande", sans-serif;
}
#wrapper{
width:600px;
margin:40px auto 0;
text-align:center;
-moz-box-shadow: 5px 5px 10px rgba(0,0,0,0.3);
-webkit-box-shadow: 5px 5px 10px rgba(0,0,0,0.3);
box-shadow: 5px 5px 10px rgba(0,0,0,0.3);
}
#wrapper h1{
color:#FFF;
text-align:center;
margin-bottom:20px;
}
#wrapper a{
display:block;
font-size:.9em;
padding-top:20px;
color:#FFF;
text-decoration:none;
text-align:center;
}
#container {
width:600px;
padding-bottom:15px;
background-color:#FFFFFF;
}
.navtop{
height:40px;
background-color:#24B2EB;
padding:13px;
}
.content {
padding:10px 10px 25px;
background: #FFFFFF;
margin:;
color:#333;
}
a.button{
color:white;
padding:15px 20px;
text-shadow:1px 1px 0 #00A5FF;
font-weight:bold;
text-align:center;
border:1px solid #24B2EB;
margin:0px 200px;
clear:both;
background-color: #24B2EB;
border-radius:100px;
-moz-border-radius:100px;
-webkit-border-radius:100px;
}
a.button:hover{
text-decoration:none;
background-color: #24B2EB;
}
</style>
</head>
<body>
<div id="wrapper">
<div id="container">
<div class="navtop">
<h1>{{.Title}}</h1>
</div>
<div id="content">
{{.Content}}
<a href="/" title="Home" class="button">Go Home</a><br />
<br>Powered by beego {{.BeegoVersion}}
</div>
</div>
</div>
</body>
</html>
`
type errorInfo struct {
controllerType reflect.Type
handler http.HandlerFunc
method string
errorType int
}
// ErrorMaps holds map of http handlers for each error string.
// there is 10 kinds default error(40x and 50x)
var ErrorMaps = make(map[string]*errorInfo, 10)
// show 401 unauthorized error.
func unauthorized(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
401,
"<br>The page you have requested can't be authorized."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>The credentials you supplied are incorrect"+
"<br>There are errors in the website address"+
"</ul>",
)
}
// show 402 Payment Required
func paymentRequired(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
402,
"<br>The page you have requested Payment Required."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>The credentials you supplied are incorrect"+
"<br>There are errors in the website address"+
"</ul>",
)
}
// show 403 forbidden error.
func forbidden(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
403,
"<br>The page you have requested is forbidden."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>Your address may be blocked"+
"<br>The site may be disabled"+
"<br>You need to log in"+
"</ul>",
)
}
// show 422 missing xsrf token
func missingxsrf(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
422,
"<br>The page you have requested is forbidden."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>'_xsrf' argument missing from POST"+
"</ul>",
)
}
// show 417 invalid xsrf token
func invalidxsrf(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
417,
"<br>The page you have requested is forbidden."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>expected XSRF not found"+
"</ul>",
)
}
// show 404 not found error.
func notFound(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
404,
"<br>The page you have requested has flown the coop."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>The page has moved"+
"<br>The page no longer exists"+
"<br>You were looking for your puppy and got lost"+
"<br>You like 404 pages"+
"</ul>",
)
}
// show 405 Method Not Allowed
func methodNotAllowed(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
405,
"<br>The method you have requested Not Allowed."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>The method specified in the Request-Line is not allowed for the resource identified by the Request-URI"+
"<br>The response MUST include an Allow header containing a list of valid methods for the requested resource."+
"</ul>",
)
}
// show 500 internal server error.
func internalServerError(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
500,
"<br>The page you have requested is down right now."+
"<br><br><ul>"+
"<br>Please try again later and report the error to the website administrator"+
"<br></ul>",
)
}
// show 501 Not Implemented.
func notImplemented(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
501,
"<br>The page you have requested is Not Implemented."+
"<br><br><ul>"+
"<br>Please try again later and report the error to the website administrator"+
"<br></ul>",
)
}
// show 502 Bad Gateway.
func badGateway(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
502,
"<br>The page you have requested is down right now."+
"<br><br><ul>"+
"<br>The server, while acting as a gateway or proxy, received an invalid response from the upstream server it accessed in attempting to fulfill the request."+
"<br>Please try again later and report the error to the website administrator"+
"<br></ul>",
)
}
// show 503 service unavailable error.
func serviceUnavailable(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
503,
"<br>The page you have requested is unavailable."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br><br>The page is overloaded"+
"<br>Please try again later."+
"</ul>",
)
}
// show 504 Gateway Timeout.
func gatewayTimeout(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
504,
"<br>The page you have requested is unavailable"+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br><br>The server, while acting as a gateway or proxy, did not receive a timely response from the upstream server specified by the URI."+
"<br>Please try again later."+
"</ul>",
)
}
// show 413 Payload Too Large
func payloadTooLarge(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
413,
`<br>The page you have requested is unavailable.
<br>Perhaps you are here because:<br><br>
<ul>
<br>The request entity is larger than limits defined by server.
<br>Please change the request entity and try again.
</ul>
`,
)
}
func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := M{
"Title": http.StatusText(errCode),
"BeegoVersion": beego.VERSION,
"Content": template.HTML(errContent),
}
t.Execute(rw, data)
}
// ErrorHandler registers http.HandlerFunc to each http err code string.
// usage:
// beego.ErrorHandler("404",NotFound)
// beego.ErrorHandler("500",InternalServerError)
func ErrorHandler(code string, h http.HandlerFunc) *HttpServer {
ErrorMaps[code] = &errorInfo{
errorType: errorTypeHandler,
handler: h,
method: code,
}
return BeeApp
}
// ErrorController registers ControllerInterface to each http err code string.
// usage:
// beego.ErrorController(&controllers.ErrorController{})
func ErrorController(c ControllerInterface) *HttpServer {
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
for i := 0; i < rt.NumMethod(); i++ {
methodName := rt.Method(i).Name
if !utils.InSlice(methodName, exceptMethod) && strings.HasPrefix(methodName, "Error") {
errName := strings.TrimPrefix(methodName, "Error")
ErrorMaps[errName] = &errorInfo{
errorType: errorTypeController,
controllerType: ct,
method: methodName,
}
}
}
return BeeApp
}
// Exception Write HttpStatus with errCode and Exec error handler if exist.
func Exception(errCode uint64, ctx *context.Context) {
exception(strconv.FormatUint(errCode, 10), ctx)
}
// show error string as simple text message.
// if error string is empty, show 503 or 500 error as default.
func exception(errCode string, ctx *context.Context) {
atoi := func(code string) int {
v, err := strconv.Atoi(code)
if err == nil {
return v
}
if ctx.Output.Status == 0 {
return 503
}
return ctx.Output.Status
}
for _, ec := range []string{errCode, "503", "500"} {
if h, ok := ErrorMaps[ec]; ok {
executeError(h, ctx, atoi(ec))
return
}
}
//if 50x error has been removed from errorMap
ctx.ResponseWriter.WriteHeader(atoi(errCode))
ctx.WriteString(errCode)
}
func executeError(err *errorInfo, ctx *context.Context, code int) {
//make sure to log the error in the access log
LogAccess(ctx, nil, code)
if err.errorType == errorTypeHandler {
ctx.ResponseWriter.WriteHeader(code)
err.handler(ctx.ResponseWriter, ctx.Request)
return
}
if err.errorType == errorTypeController {
ctx.Output.SetStatus(code)
//Invoke the request handler
vc := reflect.New(err.controllerType)
execController, ok := vc.Interface().(ControllerInterface)
if !ok {
panic("controller is not ControllerInterface")
}
//call the controller init function
execController.Init(ctx, err.controllerType.Name(), err.method, vc.Interface())
//call prepare function
execController.Prepare()
execController.URLMapping()
method := vc.MethodByName(err.method)
method.Call([]reflect.Value{})
//render template
if BConfig.WebConfig.AutoRender {
if err := execController.Render(); err != nil {
panic(err)
}
}
// finish all runrouter. release resource
execController.Finish()
}
}

88
server/web/error_test.go Normal file
View File

@ -0,0 +1,88 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
)
type errorTestController struct {
Controller
}
const parseCodeError = "parse code error"
func (ec *errorTestController) Get() {
errorCode, err := ec.GetInt("code")
if err != nil {
ec.Abort(parseCodeError)
}
if errorCode != 0 {
ec.CustomAbort(errorCode, ec.GetString("code"))
}
ec.Abort("404")
}
func TestErrorCode_01(t *testing.T) {
registerDefaultErrorHandler()
for k := range ErrorMaps {
r, _ := http.NewRequest("GET", "/error?code="+k, nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/error", &errorTestController{})
handler.ServeHTTP(w, r)
code, _ := strconv.Atoi(k)
if w.Code != code {
t.Fail()
}
if !strings.Contains(w.Body.String(), http.StatusText(code)) {
t.Fail()
}
}
}
func TestErrorCode_02(t *testing.T) {
registerDefaultErrorHandler()
r, _ := http.NewRequest("GET", "/error?code=0", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/error", &errorTestController{})
handler.ServeHTTP(w, r)
if w.Code != 404 {
t.Fail()
}
}
func TestErrorCode_03(t *testing.T) {
registerDefaultErrorHandler()
r, _ := http.NewRequest("GET", "/error?code=panic", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/error", &errorTestController{})
handler.ServeHTTP(w, r)
if w.Code != 200 {
t.Fail()
}
if w.Body.String() != parseCodeError {
t.Fail()
}
}

134
server/web/filter.go Normal file
View File

@ -0,0 +1,134 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"strings"
"github.com/astaxie/beego/server/web/context"
)
// FilterChain is different from pure FilterFunc
// when you use this, you must invoke next(ctx) inside the FilterFunc which is returned
// And all those FilterChain will be invoked before other FilterFunc
type FilterChain func(next FilterFunc) FilterFunc
// FilterFunc defines a filter function which is invoked before the controller handler is executed.
type FilterFunc func(ctx *context.Context)
// FilterRouter defines a filter operation which is invoked before the controller handler is executed.
// It can match the URL against a pattern, and execute a filter function
// when a request with a matching URL arrives.
type FilterRouter struct {
filterFunc FilterFunc
next *FilterRouter
tree *Tree
pattern string
returnOnOutput bool
resetParams bool
}
// params is for:
// 1. setting the returnOnOutput value (false allows multiple filters to execute)
// 2. determining whether or not params need to be reset.
func newFilterRouter(pattern string, filter FilterFunc, opts ...FilterOpt) *FilterRouter {
mr := &FilterRouter{
tree: NewTree(),
pattern: pattern,
filterFunc: filter,
}
fos := &filterOpts{
returnOnOutput: true,
}
for _, o := range opts {
o(fos)
}
if !fos.routerCaseSensitive {
mr.pattern = strings.ToLower(pattern)
}
mr.returnOnOutput = fos.returnOnOutput
mr.resetParams = fos.resetParams
mr.tree.AddRouter(pattern, true)
return mr
}
// filter will check whether we need to execute the filter logic
// return (started, done)
func (f *FilterRouter) filter(ctx *context.Context, urlPath string, preFilterParams map[string]string) (bool, bool) {
if f.returnOnOutput && ctx.ResponseWriter.Started {
return true, true
}
if f.resetParams {
preFilterParams = ctx.Input.Params()
}
if ok := f.ValidRouter(urlPath, ctx); ok {
f.filterFunc(ctx)
if f.resetParams {
ctx.Input.ResetParams()
for k, v := range preFilterParams {
ctx.Input.SetParam(k, v)
}
}
} else if f.next != nil {
return f.next.filter(ctx, urlPath, preFilterParams)
}
if f.returnOnOutput && ctx.ResponseWriter.Started {
return true, true
}
return false, false
}
// ValidRouter checks if the current request is matched by this filter.
// If the request is matched, the values of the URL parameters defined
// by the filter pattern are also returned.
func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool {
isOk := f.tree.Match(url, ctx)
if isOk != nil {
if b, ok := isOk.(bool); ok {
return b
}
}
return false
}
type filterOpts struct {
returnOnOutput bool
resetParams bool
routerCaseSensitive bool
}
type FilterOpt func(opts *filterOpts)
func WithReturnOnOutput(ret bool) FilterOpt {
return func(opts *filterOpts) {
opts.returnOnOutput = ret
}
}
func WithResetParams(reset bool) FilterOpt {
return func(opts *filterOpts) {
opts.resetParams = reset
}
}
func WithCaseSensitive(sensitive bool) FilterOpt {
return func(opts *filterOpts) {
opts.routerCaseSensitive = sensitive
}
}

View File

@ -0,0 +1,162 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package apiauth provides handlers to enable apiauth support.
//
// Simple Usage:
// import(
// "github.com/astaxie/beego"
// "github.com/astaxie/beego/plugins/apiauth"
// )
//
// func main(){
// // apiauth every request
// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey"))
// beego.Run()
// }
//
// Advanced Usage:
//
// func getAppSecret(appid string) string {
// // get appsecret by appid
// // maybe store in configure, maybe in database
// }
//
// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APISecretAuth(getAppSecret, 360))
//
// Information:
//
// In the request user should include these params in the query
//
// 1. appid
//
// appid is assigned to the application
//
// 2. signature
//
// get the signature use apiauth.Signature()
//
// when you send to server remember use url.QueryEscape()
//
// 3. timestamp:
//
// send the request time, the format is yyyy-mm-dd HH:ii:ss
//
package apiauth
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/url"
"sort"
"time"
"github.com/astaxie/beego/server/web"
"github.com/astaxie/beego/server/web/context"
)
// AppIDToAppSecret gets appsecret through appid
type AppIDToAppSecret func(string) string
// APIBasicAuth uses the basic appid/appkey as the AppIdToAppSecret
func APIBasicAuth(appid, appkey string) web.FilterFunc {
ft := func(aid string) string {
if aid == appid {
return appkey
}
return ""
}
return APISecretAuth(ft, 300)
}
// APISecretAuth uses AppIdToAppSecret verify and
func APISecretAuth(f AppIDToAppSecret, timeout int) web.FilterFunc {
return func(ctx *context.Context) {
if ctx.Input.Query("appid") == "" {
ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("missing query parameter: appid")
return
}
appsecret := f(ctx.Input.Query("appid"))
if appsecret == "" {
ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("appid query parameter missing")
return
}
if ctx.Input.Query("signature") == "" {
ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("missing query parameter: signature")
return
}
if ctx.Input.Query("timestamp") == "" {
ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("missing query parameter: timestamp")
return
}
u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp"))
if err != nil {
ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("incorrect timestamp format. Should be in the form 2006-01-02 15:04:05")
return
}
t := time.Now()
if t.Sub(u).Seconds() > float64(timeout) {
ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("request timer timeout exceeded. Please try again")
return
}
if ctx.Input.Query("signature") !=
Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URL()) {
ctx.ResponseWriter.WriteHeader(403)
ctx.WriteString("authentication failed")
}
}
}
// Signature generates signature with appsecret/method/params/RequestURI
func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) {
var b bytes.Buffer
keys := make([]string, len(params))
pa := make(map[string]string)
for k, v := range params {
pa[k] = v[0]
keys = append(keys, k)
}
sort.Strings(keys)
for _, key := range keys {
if key == "signature" {
continue
}
val := pa[key]
if key != "" && val != "" {
b.WriteString(key)
b.WriteString(val)
}
}
stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL)
sha256 := sha256.New
hash := hmac.New(sha256, []byte(appsecret))
hash.Write([]byte(stringToSign))
return base64.StdEncoding.EncodeToString(hash.Sum(nil))
}

View File

@ -0,0 +1,20 @@
package apiauth
import (
"net/url"
"testing"
)
func TestSignature(t *testing.T) {
appsecret := "beego secret"
method := "GET"
RequestURL := "http://localhost/test/url"
params := make(url.Values)
params.Add("arg1", "hello")
params.Add("arg2", "beego")
signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58="
if Signature(appsecret, method, params, RequestURL) != signature {
t.Error("Signature error")
}
}

View File

@ -0,0 +1,107 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package auth provides handlers to enable basic auth support.
// Simple Usage:
// import(
// "github.com/astaxie/beego"
// "github.com/astaxie/beego/plugins/auth"
// )
//
// func main(){
// // authenticate every request
// beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword"))
// beego.Run()
// }
//
//
// Advanced Usage:
//
// func SecretAuth(username, password string) bool {
// return username == "astaxie" && password == "helloBeego"
// }
// authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required")
// beego.InsertFilter("*", beego.BeforeRouter,authPlugin)
package auth
import (
"encoding/base64"
"net/http"
"strings"
"github.com/astaxie/beego/server/web"
"github.com/astaxie/beego/server/web/context"
)
var defaultRealm = "Authorization Required"
// Basic is the http basic auth
func Basic(username string, password string) web.FilterFunc {
secrets := func(user, pass string) bool {
return user == username && pass == password
}
return NewBasicAuthenticator(secrets, defaultRealm)
}
// NewBasicAuthenticator return the BasicAuth
func NewBasicAuthenticator(secrets SecretProvider, Realm string) web.FilterFunc {
return func(ctx *context.Context) {
a := &BasicAuth{Secrets: secrets, Realm: Realm}
if username := a.CheckAuth(ctx.Request); username == "" {
a.RequireAuth(ctx.ResponseWriter, ctx.Request)
}
}
}
// SecretProvider is the SecretProvider function
type SecretProvider func(user, pass string) bool
// BasicAuth store the SecretProvider and Realm
type BasicAuth struct {
Secrets SecretProvider
Realm string
}
// CheckAuth Checks the username/password combination from the request. Returns
// either an empty string (authentication failed) or the name of the
// authenticated user.
// Supports MD5 and SHA1 password entries
func (a *BasicAuth) CheckAuth(r *http.Request) string {
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
if len(s) != 2 || s[0] != "Basic" {
return ""
}
b, err := base64.StdEncoding.DecodeString(s[1])
if err != nil {
return ""
}
pair := strings.SplitN(string(b), ":", 2)
if len(pair) != 2 {
return ""
}
if a.Secrets(pair[0], pair[1]) {
return pair[0]
}
return ""
}
// RequireAuth http.Handler for BasicAuth which initiates the authentication process
// (or requires reauthentication).
func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`)
w.WriteHeader(401)
w.Write([]byte("401 Unauthorized\n"))
}

View File

@ -0,0 +1,88 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support.
// Simple Usage:
// import(
// "github.com/astaxie/beego"
// "github.com/astaxie/beego/plugins/authz"
// "github.com/casbin/casbin"
// )
//
// func main(){
// // mediate the access for every request
// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
// beego.Run()
// }
//
//
// Advanced Usage:
//
// func main(){
// e := casbin.NewEnforcer("authz_model.conf", "")
// e.AddRoleForUser("alice", "admin")
// e.AddPolicy(...)
//
// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e))
// beego.Run()
// }
package authz
import (
"net/http"
"github.com/casbin/casbin"
"github.com/astaxie/beego/server/web"
"github.com/astaxie/beego/server/web/context"
)
// NewAuthorizer returns the authorizer.
// Use a casbin enforcer as input
func NewAuthorizer(e *casbin.Enforcer) web.FilterFunc {
return func(ctx *context.Context) {
a := &BasicAuthorizer{enforcer: e}
if !a.CheckPermission(ctx.Request) {
a.RequirePermission(ctx.ResponseWriter)
}
}
}
// BasicAuthorizer stores the casbin handler
type BasicAuthorizer struct {
enforcer *casbin.Enforcer
}
// GetUserName gets the user name from the request.
// Currently, only HTTP basic authentication is supported
func (a *BasicAuthorizer) GetUserName(r *http.Request) string {
username, _, _ := r.BasicAuth()
return username
}
// CheckPermission checks the user/method/path combination from the request.
// Returns true (permission granted) or false (permission forbidden)
func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool {
user := a.GetUserName(r)
method := r.Method
path := r.URL.Path
return a.enforcer.Enforce(user, path, method)
}
// RequirePermission returns the 403 Forbidden to the client
func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) {
w.WriteHeader(403)
w.Write([]byte("403 Forbidden\n"))
}

View File

@ -0,0 +1,14 @@
[request_definition]
r = sub, obj, act
[policy_definition]
p = sub, obj, act
[role_definition]
g = _, _
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*")

View File

@ -0,0 +1,7 @@
p, alice, /dataset1/*, GET
p, alice, /dataset1/resource1, POST
p, bob, /dataset2/resource1, *
p, bob, /dataset2/resource2, GET
p, bob, /dataset2/folder1/*, POST
p, dataset1_admin, /dataset1/*, *
g, cathy, dataset1_admin
1 p, alice, /dataset1/*, GET
2 p, alice, /dataset1/resource1, POST
3 p, bob, /dataset2/resource1, *
4 p, bob, /dataset2/resource2, GET
5 p, bob, /dataset2/folder1/*, POST
6 p, dataset1_admin, /dataset1/*, *
7 g, cathy, dataset1_admin

View File

@ -0,0 +1,109 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package authz
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/casbin/casbin"
"github.com/astaxie/beego/server/web"
"github.com/astaxie/beego/server/web/context"
"github.com/astaxie/beego/server/web/filter/auth"
)
func testRequest(t *testing.T, handler *web.ControllerRegister, user string, path string, method string, code int) {
r, _ := http.NewRequest(method, path, nil)
r.SetBasicAuth(user, "123")
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
if w.Code != code {
t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code)
}
}
func TestBasic(t *testing.T) {
handler := web.NewControllerRegister()
handler.InsertFilter("*", web.BeforeRouter, auth.Basic("alice", "123"))
handler.InsertFilter("*", web.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
handler.Any("*", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200)
testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200)
testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200)
testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403)
}
func TestPathWildcard(t *testing.T) {
handler := web.NewControllerRegister()
handler.InsertFilter("*", web.BeforeRouter, auth.Basic("bob", "123"))
handler.InsertFilter("*", web.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
handler.Any("*", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200)
testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200)
testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200)
testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200)
testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403)
testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200)
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200)
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403)
}
func TestRBAC(t *testing.T) {
handler := web.NewControllerRegister()
handler.InsertFilter("*", web.BeforeRouter, auth.Basic("cathy", "123"))
e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")
handler.InsertFilter("*", web.BeforeRouter, NewAuthorizer(e))
handler.Any("*", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
// cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role.
testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200)
testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200)
testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200)
testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
// delete all roles on user cathy, so cathy cannot access any resources now.
e.DeleteRolesForUser("cathy")
testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403)
testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403)
testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
}

View File

@ -0,0 +1,228 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package cors provides handlers to enable CORS support.
// Usage
// import (
// "github.com/astaxie/beego"
// "github.com/astaxie/beego/plugins/cors"
// )
//
// func main() {
// // CORS for https://foo.* origins, allowing:
// // - PUT and PATCH methods
// // - Origin header
// // - Credentials share
// beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{
// AllowOrigins: []string{"https://*.foo.com"},
// AllowMethods: []string{"PUT", "PATCH"},
// AllowHeaders: []string{"Origin"},
// ExposeHeaders: []string{"Content-Length"},
// AllowCredentials: true,
// }))
// beego.Run()
// }
package cors
import (
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/astaxie/beego/server/web"
"github.com/astaxie/beego/server/web/context"
)
const (
headerAllowOrigin = "Access-Control-Allow-Origin"
headerAllowCredentials = "Access-Control-Allow-Credentials"
headerAllowHeaders = "Access-Control-Allow-Headers"
headerAllowMethods = "Access-Control-Allow-Methods"
headerExposeHeaders = "Access-Control-Expose-Headers"
headerMaxAge = "Access-Control-Max-Age"
headerOrigin = "Origin"
headerRequestMethod = "Access-Control-Request-Method"
headerRequestHeaders = "Access-Control-Request-Headers"
)
var (
defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"}
// Regex patterns are generated from AllowOrigins. These are used and generated internally.
allowOriginPatterns = []string{}
)
// Options represents Access Control options.
type Options struct {
// If set, all origins are allowed.
AllowAllOrigins bool
// A list of allowed origins. Wild cards and FQDNs are supported.
AllowOrigins []string
// If set, allows to share auth credentials such as cookies.
AllowCredentials bool
// A list of allowed HTTP methods.
AllowMethods []string
// A list of allowed HTTP headers.
AllowHeaders []string
// A list of exposed HTTP headers.
ExposeHeaders []string
// Max age of the CORS headers.
MaxAge time.Duration
}
// Header converts options into CORS headers.
func (o *Options) Header(origin string) (headers map[string]string) {
headers = make(map[string]string)
// if origin is not allowed, don't extend the headers
// with CORS headers.
if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
return
}
// add allow origin
if o.AllowAllOrigins {
headers[headerAllowOrigin] = "*"
} else {
headers[headerAllowOrigin] = origin
}
// add allow credentials
headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
// add allow methods
if len(o.AllowMethods) > 0 {
headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
}
// add allow headers
if len(o.AllowHeaders) > 0 {
headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",")
}
// add exposed header
if len(o.ExposeHeaders) > 0 {
headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
}
// add a max age header
if o.MaxAge > time.Duration(0) {
headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
}
return
}
// PreflightHeader converts options into CORS headers for a preflight response.
func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) {
headers = make(map[string]string)
if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
return
}
// verify if requested method is allowed
for _, method := range o.AllowMethods {
if method == rMethod {
headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
break
}
}
// verify if requested headers are allowed
var allowed []string
for _, rHeader := range strings.Split(rHeaders, ",") {
rHeader = strings.TrimSpace(rHeader)
lookupLoop:
for _, allowedHeader := range o.AllowHeaders {
if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) {
allowed = append(allowed, rHeader)
break lookupLoop
}
}
}
headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
// add allow origin
if o.AllowAllOrigins {
headers[headerAllowOrigin] = "*"
} else {
headers[headerAllowOrigin] = origin
}
// add allowed headers
if len(allowed) > 0 {
headers[headerAllowHeaders] = strings.Join(allowed, ",")
}
// add exposed headers
if len(o.ExposeHeaders) > 0 {
headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
}
// add a max age header
if o.MaxAge > time.Duration(0) {
headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
}
return
}
// IsOriginAllowed looks up if the origin matches one of the patterns
// generated from Options.AllowOrigins patterns.
func (o *Options) IsOriginAllowed(origin string) (allowed bool) {
for _, pattern := range allowOriginPatterns {
allowed, _ = regexp.MatchString(pattern, origin)
if allowed {
return
}
}
return
}
// Allow enables CORS for requests those match the provided options.
func Allow(opts *Options) web.FilterFunc {
// Allow default headers if nothing is specified.
if len(opts.AllowHeaders) == 0 {
opts.AllowHeaders = defaultAllowHeaders
}
for _, origin := range opts.AllowOrigins {
pattern := regexp.QuoteMeta(origin)
pattern = strings.Replace(pattern, "\\*", ".*", -1)
pattern = strings.Replace(pattern, "\\?", ".", -1)
allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$")
}
return func(ctx *context.Context) {
var (
origin = ctx.Input.Header(headerOrigin)
requestedMethod = ctx.Input.Header(headerRequestMethod)
requestedHeaders = ctx.Input.Header(headerRequestHeaders)
// additional headers to be added
// to the response.
headers map[string]string
)
if ctx.Input.Method() == "OPTIONS" &&
(requestedMethod != "" || requestedHeaders != "") {
headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders)
for key, value := range headers {
ctx.Output.Header(key, value)
}
ctx.ResponseWriter.WriteHeader(http.StatusOK)
return
}
headers = opts.Header(origin)
for key, value := range headers {
ctx.Output.Header(key, value)
}
}
}

View File

@ -0,0 +1,253 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cors
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/astaxie/beego/server/web"
"github.com/astaxie/beego/server/web/context"
)
// HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header
type HTTPHeaderGuardRecorder struct {
*httptest.ResponseRecorder
savedHeaderMap http.Header
}
// NewRecorder return HttpHeaderGuardRecorder
func NewRecorder() *HTTPHeaderGuardRecorder {
return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil}
}
func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) {
gr.ResponseRecorder.WriteHeader(code)
gr.savedHeaderMap = gr.ResponseRecorder.Header()
}
func (gr *HTTPHeaderGuardRecorder) Header() http.Header {
if gr.savedHeaderMap != nil {
// headers were written. clone so we don't get updates
clone := make(http.Header)
for k, v := range gr.savedHeaderMap {
clone[k] = v
}
return clone
}
return gr.ResponseRecorder.Header()
}
func Test_AllowAll(t *testing.T) {
recorder := httptest.NewRecorder()
handler := web.NewControllerRegister()
handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{
AllowAllOrigins: true,
}))
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
r, _ := http.NewRequest("PUT", "/foo", nil)
handler.ServeHTTP(recorder, r)
if recorder.HeaderMap.Get(headerAllowOrigin) != "*" {
t.Errorf("Allow-Origin header should be *")
}
}
func Test_AllowRegexMatch(t *testing.T) {
recorder := httptest.NewRecorder()
handler := web.NewControllerRegister()
handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{
AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"},
}))
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
origin := "https://bar.foo.com"
r, _ := http.NewRequest("PUT", "/foo", nil)
r.Header.Add("Origin", origin)
handler.ServeHTTP(recorder, r)
headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
if headerValue != origin {
t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue)
}
}
func Test_AllowRegexNoMatch(t *testing.T) {
recorder := httptest.NewRecorder()
handler := web.NewControllerRegister()
handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{
AllowOrigins: []string{"https://*.foo.com"},
}))
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
origin := "https://ww.foo.com.evil.com"
r, _ := http.NewRequest("PUT", "/foo", nil)
r.Header.Add("Origin", origin)
handler.ServeHTTP(recorder, r)
headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
if headerValue != "" {
t.Errorf("Allow-Origin header should not exist, found %v", headerValue)
}
}
func Test_OtherHeaders(t *testing.T) {
recorder := httptest.NewRecorder()
handler := web.NewControllerRegister()
handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{
AllowAllOrigins: true,
AllowCredentials: true,
AllowMethods: []string{"PATCH", "GET"},
AllowHeaders: []string{"Origin", "X-whatever"},
ExposeHeaders: []string{"Content-Length", "Hello"},
MaxAge: 5 * time.Minute,
}))
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
r, _ := http.NewRequest("PUT", "/foo", nil)
handler.ServeHTTP(recorder, r)
credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials)
methodsVal := recorder.HeaderMap.Get(headerAllowMethods)
headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders)
maxAgeVal := recorder.HeaderMap.Get(headerMaxAge)
if credentialsVal != "true" {
t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal)
}
if methodsVal != "PATCH,GET" {
t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal)
}
if headersVal != "Origin,X-whatever" {
t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal)
}
if exposedHeadersVal != "Content-Length,Hello" {
t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal)
}
if maxAgeVal != "300" {
t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal)
}
}
func Test_DefaultAllowHeaders(t *testing.T) {
recorder := httptest.NewRecorder()
handler := web.NewControllerRegister()
handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{
AllowAllOrigins: true,
}))
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
r, _ := http.NewRequest("PUT", "/foo", nil)
handler.ServeHTTP(recorder, r)
headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
if headersVal != "Origin,Accept,Content-Type,Authorization" {
t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal)
}
}
func Test_Preflight(t *testing.T) {
recorder := NewRecorder()
handler := web.NewControllerRegister()
handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{
AllowAllOrigins: true,
AllowMethods: []string{"PUT", "PATCH"},
AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"},
}))
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
r, _ := http.NewRequest("OPTIONS", "/foo", nil)
r.Header.Add(headerRequestMethod, "PUT")
r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive")
handler.ServeHTTP(recorder, r)
headers := recorder.Header()
methodsVal := headers.Get(headerAllowMethods)
headersVal := headers.Get(headerAllowHeaders)
originVal := headers.Get(headerAllowOrigin)
if methodsVal != "PUT,PATCH" {
t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal)
}
if !strings.Contains(headersVal, "X-whatever") {
t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal)
}
if !strings.Contains(headersVal, "x-casesensitive") {
t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal)
}
if originVal != "*" {
t.Errorf("Allow-Origin is expected to be *, found %v", originVal)
}
if recorder.Code != http.StatusOK {
t.Errorf("Status code is expected to be 200, found %d", recorder.Code)
}
}
func Benchmark_WithoutCORS(b *testing.B) {
recorder := httptest.NewRecorder()
handler := web.NewControllerRegister()
web.BConfig.RunMode = web.PROD
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
b.ResetTimer()
r, _ := http.NewRequest("PUT", "/foo", nil)
for i := 0; i < b.N; i++ {
handler.ServeHTTP(recorder, r)
}
}
func Benchmark_WithCORS(b *testing.B) {
recorder := httptest.NewRecorder()
handler := web.NewControllerRegister()
web.BConfig.RunMode = web.PROD
handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{
AllowAllOrigins: true,
AllowCredentials: true,
AllowMethods: []string{"PATCH", "GET"},
AllowHeaders: []string{"Origin", "X-whatever"},
MaxAge: 5 * time.Minute,
}))
handler.Any("/foo", func(ctx *context.Context) {
ctx.Output.SetStatus(500)
})
b.ResetTimer()
r, _ := http.NewRequest("PUT", "/foo", nil)
for i := 0; i < b.N; i++ {
handler.ServeHTTP(recorder, r)
}
}

View File

@ -0,0 +1,86 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package opentracing
import (
"context"
"github.com/astaxie/beego/server/web"
beegoCtx "github.com/astaxie/beego/server/web/context"
logKit "github.com/go-kit/kit/log"
opentracingKit "github.com/go-kit/kit/tracing/opentracing"
"github.com/opentracing/opentracing-go"
)
// FilterChainBuilder provides an extension point that we can support more configurations if necessary
type FilterChainBuilder struct {
// CustomSpanFunc makes users to custom the span.
CustomSpanFunc func(span opentracing.Span, ctx *beegoCtx.Context)
}
func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc {
return func(ctx *beegoCtx.Context) {
var (
spanCtx context.Context
span opentracing.Span
)
operationName := builder.operationName(ctx)
if preSpan := opentracing.SpanFromContext(ctx.Request.Context()); preSpan == nil {
inject := opentracingKit.HTTPToContext(opentracing.GlobalTracer(), operationName, logKit.NewNopLogger())
spanCtx = inject(ctx.Request.Context(), ctx.Request)
span = opentracing.SpanFromContext(spanCtx)
} else {
span, spanCtx = opentracing.StartSpanFromContext(ctx.Request.Context(), operationName)
}
defer span.Finish()
newReq := ctx.Request.Clone(spanCtx)
ctx.Reset(ctx.ResponseWriter.ResponseWriter, newReq)
next(ctx)
// if you think we need to do more things, feel free to create an issue to tell us
span.SetTag("http.status_code", ctx.ResponseWriter.Status)
span.SetTag("http.method", ctx.Input.Method())
span.SetTag("peer.hostname", ctx.Request.Host)
span.SetTag("http.url", ctx.Request.URL.String())
span.SetTag("http.scheme", ctx.Request.URL.Scheme)
span.SetTag("span.kind", "server")
span.SetTag("component", "beego")
if ctx.Output.IsServerError() || ctx.Output.IsClientError() {
span.SetTag("error", true)
}
span.SetTag("peer.address", ctx.Request.RemoteAddr)
span.SetTag("http.proto", ctx.Request.Proto)
span.SetTag("beego.route", ctx.Input.GetData("RouterPattern"))
if builder.CustomSpanFunc != nil {
builder.CustomSpanFunc(span, ctx)
}
}
}
func (builder *FilterChainBuilder) operationName(ctx *beegoCtx.Context) string {
operationName := ctx.Input.URL()
// it means that there is not any span, so we create a span as the root span.
// TODO, if we support multiple servers, this need to be changed
route, found := web.BeeApp.Handlers.FindRouter(ctx)
if found {
operationName = ctx.Input.Method() + "#" + route.GetPattern()
}
return operationName
}

View File

@ -0,0 +1,47 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package opentracing
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/opentracing/opentracing-go"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/server/web/context"
)
func TestFilterChainBuilder_FilterChain(t *testing.T) {
builder := &FilterChainBuilder{
CustomSpanFunc: func(span opentracing.Span, ctx *context.Context) {
span.SetTag("aa", "bbb")
},
}
ctx := context.NewContext()
r, _ := http.NewRequest("GET", "/prometheus/user", nil)
w := httptest.NewRecorder()
ctx.Reset(w, r)
ctx.Input.SetData("RouterPattern", "my-route")
filterFunc := builder.FilterChain(func(ctx *context.Context) {
ctx.Input.SetData("opentracing", true)
})
filterFunc(ctx)
assert.True(t, ctx.Input.GetData("opentracing").(bool))
}

View File

@ -0,0 +1,87 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package prometheus
import (
"strconv"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/astaxie/beego"
"github.com/astaxie/beego/server/web"
"github.com/astaxie/beego/server/web/context"
)
// FilterChainBuilder is an extension point,
// when we want to support some configuration,
// please use this structure
type FilterChainBuilder struct {
}
// FilterChain returns a FilterFunc. The filter will records some metrics
func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc {
summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{
Name: "beego",
Subsystem: "http_request",
ConstLabels: map[string]string{
"server": web.BConfig.ServerName,
"env": web.BConfig.RunMode,
"appname": web.BConfig.AppName,
},
Help: "The statics info for http request",
}, []string{"pattern", "method", "status", "duration"})
prometheus.MustRegister(summaryVec)
registerBuildInfo()
return func(ctx *context.Context) {
startTime := time.Now()
next(ctx)
endTime := time.Now()
go report(endTime.Sub(startTime), ctx, summaryVec)
}
}
func registerBuildInfo() {
buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "beego",
Subsystem: "build_info",
Help: "The building information",
ConstLabels: map[string]string{
"appname": web.BConfig.AppName,
"build_version": beego.BuildVersion,
"build_revision": beego.BuildGitRevision,
"build_status": beego.BuildStatus,
"build_tag": beego.BuildTag,
"build_time": strings.Replace(beego.BuildTime, "--", " ", 1),
"go_version": beego.GoVersion,
"git_branch": beego.GitBranch,
"start_time": time.Now().Format("2006-01-02 15:04:05"),
},
}, []string{})
prometheus.MustRegister(buildInfo)
buildInfo.WithLabelValues().Set(1)
}
func report(dur time.Duration, ctx *context.Context, vec *prometheus.SummaryVec) {
status := ctx.Output.Status
ptn := ctx.Input.GetData("RouterPattern").(string)
ms := dur / time.Millisecond
vec.WithLabelValues(ptn, ctx.Input.Method(), strconv.Itoa(status), strconv.Itoa(int(ms))).Observe(float64(ms))
}

View File

@ -0,0 +1,40 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package prometheus
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/server/web/context"
)
func TestFilterChain(t *testing.T) {
filter := (&FilterChainBuilder{}).FilterChain(func(ctx *context.Context) {
// do nothing
ctx.Input.SetData("invocation", true)
})
ctx := context.NewContext()
r, _ := http.NewRequest("GET", "/prometheus/user", nil)
w := httptest.NewRecorder()
ctx.Reset(w, r)
ctx.Input.SetData("RouterPattern", "my-route")
filter(ctx)
assert.True(t, ctx.Input.GetData("invocation").(bool))
}

View File

@ -0,0 +1,48 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/server/web/context"
)
func TestControllerRegister_InsertFilterChain(t *testing.T) {
InsertFilterChain("/*", func(next FilterFunc) FilterFunc {
return func(ctx *context.Context) {
ctx.Output.Header("filter", "filter-chain")
next(ctx)
}
})
ns := NewNamespace("/chain")
ns.Get("/*", func(ctx *context.Context) {
ctx.Output.Body([]byte("hello"))
})
r, _ := http.NewRequest("GET", "/chain/user", nil)
w := httptest.NewRecorder()
BeeApp.Handlers.ServeHTTP(w, r)
assert.Equal(t, "filter-chain", w.Header().Get("filter"))
}

68
server/web/filter_test.go Normal file
View File

@ -0,0 +1,68 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/astaxie/beego/server/web/context"
)
var FilterUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am " + ctx.Input.Param(":last") + ctx.Input.Param(":first")))
}
func TestFilter(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/asta/Xie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.InsertFilter("/person/:last/:first", BeforeRouter, FilterUser)
handler.Add("/person/:last/:first", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am astaXie" {
t.Errorf("user define func can't run")
}
}
var FilterAdminUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am admin"))
}
// Filter pattern /admin/:all
// all url like /admin/ /admin/xie will all get filter
func TestPatternTwo(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.InsertFilter("/admin/?:all", BeforeRouter, FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/ can't run")
}
}
func TestPatternThree(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/astaxie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.InsertFilter("/admin/:all", BeforeRouter, FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/astaxie can't run")
}
}

110
server/web/flash.go Normal file
View File

@ -0,0 +1,110 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"fmt"
"net/url"
"strings"
)
// FlashData is a tools to maintain data when using across request.
type FlashData struct {
Data map[string]string
}
// NewFlash return a new empty FlashData struct.
func NewFlash() *FlashData {
return &FlashData{
Data: make(map[string]string),
}
}
// Set message to flash
func (fd *FlashData) Set(key string, msg string, args ...interface{}) {
if len(args) == 0 {
fd.Data[key] = msg
} else {
fd.Data[key] = fmt.Sprintf(msg, args...)
}
}
// Success writes success message to flash.
func (fd *FlashData) Success(msg string, args ...interface{}) {
if len(args) == 0 {
fd.Data["success"] = msg
} else {
fd.Data["success"] = fmt.Sprintf(msg, args...)
}
}
// Notice writes notice message to flash.
func (fd *FlashData) Notice(msg string, args ...interface{}) {
if len(args) == 0 {
fd.Data["notice"] = msg
} else {
fd.Data["notice"] = fmt.Sprintf(msg, args...)
}
}
// Warning writes warning message to flash.
func (fd *FlashData) Warning(msg string, args ...interface{}) {
if len(args) == 0 {
fd.Data["warning"] = msg
} else {
fd.Data["warning"] = fmt.Sprintf(msg, args...)
}
}
// Error writes error message to flash.
func (fd *FlashData) Error(msg string, args ...interface{}) {
if len(args) == 0 {
fd.Data["error"] = msg
} else {
fd.Data["error"] = fmt.Sprintf(msg, args...)
}
}
// Store does the saving operation of flash data.
// the data are encoded and saved in cookie.
func (fd *FlashData) Store(c *Controller) {
c.Data["flash"] = fd.Data
var flashValue string
for key, value := range fd.Data {
flashValue += "\x00" + key + "\x23" + BConfig.WebConfig.FlashSeparator + "\x23" + value + "\x00"
}
c.Ctx.SetCookie(BConfig.WebConfig.FlashName, url.QueryEscape(flashValue), 0, "/")
}
// ReadFromRequest parsed flash data from encoded values in cookie.
func ReadFromRequest(c *Controller) *FlashData {
flash := NewFlash()
if cookie, err := c.Ctx.Request.Cookie(BConfig.WebConfig.FlashName); err == nil {
v, _ := url.QueryUnescape(cookie.Value)
vals := strings.Split(v, "\x00")
for _, v := range vals {
if len(v) > 0 {
kv := strings.Split(v, "\x23"+BConfig.WebConfig.FlashSeparator+"\x23")
if len(kv) == 2 {
flash.Data[kv[0]] = kv[1]
}
}
}
//read one time then delete it
c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/")
}
c.Data["flash"] = flash.Data
return flash
}

54
server/web/flash_test.go Normal file
View File

@ -0,0 +1,54 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type TestFlashController struct {
Controller
}
func (t *TestFlashController) TestWriteFlash() {
flash := NewFlash()
flash.Notice("TestFlashString")
flash.Store(&t.Controller)
// we choose to serve json because we don't want to load a template html file
t.ServeJSON(true)
}
func TestFlashHeader(t *testing.T) {
// create fake GET request
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
// setup the handler
handler := NewControllerRegister()
handler.Add("/", &TestFlashController{}, "get:TestWriteFlash")
handler.ServeHTTP(w, r)
// get the Set-Cookie value
sc := w.Header().Get("Set-Cookie")
// match for the expected header
res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
// validate the assertion
if !res {
t.Errorf("TestFlashHeader() unable to validate flash message")
}
}

74
server/web/fs.go Normal file
View File

@ -0,0 +1,74 @@
package web
import (
"net/http"
"os"
"path/filepath"
)
type FileSystem struct {
}
func (d FileSystem) Open(name string) (http.File, error) {
return os.Open(name)
}
// Walk walks the file tree rooted at root in filesystem, calling walkFn for each file or
// directory in the tree, including root. All errors that arise visiting files
// and directories are filtered by walkFn.
func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error {
f, err := fs.Open(root)
if err != nil {
return err
}
info, err := f.Stat()
if err != nil {
err = walkFn(root, nil, err)
} else {
err = walk(fs, root, info, walkFn)
}
if err == filepath.SkipDir {
return nil
}
return err
}
// walk recursively descends path, calling walkFn.
func walk(fs http.FileSystem, path string, info os.FileInfo, walkFn filepath.WalkFunc) error {
var err error
if !info.IsDir() {
return walkFn(path, info, nil)
}
dir, err := fs.Open(path)
if err != nil {
if err1 := walkFn(path, info, err); err1 != nil {
return err1
}
return err
}
defer dir.Close()
dirs, err := dir.Readdir(-1)
err1 := walkFn(path, info, err)
// If err != nil, walk can't walk into this directory.
// err1 != nil means walkFn want walk to skip this directory or stop walking.
// Therefore, if one of err and err1 isn't nil, walk will return.
if err != nil || err1 != nil {
// The caller's behavior is controlled by the return value, which is decided
// by walkFn. walkFn may ignore err and return nil.
// If walkFn returns SkipDir, it will be handled by the caller.
// So walk should return whatever walkFn returns.
return err1
}
for _, fileInfo := range dirs {
filename := filepath.Join(path, fileInfo.Name())
if err = walk(fs, filename, fileInfo, walkFn); err != nil {
if !fileInfo.IsDir() || err != filepath.SkipDir {
return err
}
}
}
return nil
}

166
server/web/grace/grace.go Normal file
View File

@ -0,0 +1,166 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package grace use to hot reload
// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/
//
// Usage:
//
// import(
// "log"
// "net/http"
// "os"
//
// "github.com/astaxie/beego/grace"
// )
//
// func handler(w http.ResponseWriter, r *http.Request) {
// w.Write([]byte("WORLD!"))
// }
//
// func main() {
// mux := http.NewServeMux()
// mux.HandleFunc("/hello", handler)
//
// err := grace.ListenAndServe("localhost:8080", mux)
// if err != nil {
// log.Println(err)
// }
// log.Println("Server on 8080 stopped")
// os.Exit(0)
// }
package grace
import (
"flag"
"net/http"
"os"
"strings"
"sync"
"syscall"
"time"
)
const (
// PreSignal is the position to add filter before signal
PreSignal = iota
// PostSignal is the position to add filter after signal
PostSignal
// StateInit represent the application inited
StateInit
// StateRunning represent the application is running
StateRunning
// StateShuttingDown represent the application is shutting down
StateShuttingDown
// StateTerminate represent the application is killed
StateTerminate
)
var (
regLock *sync.Mutex
runningServers map[string]*Server
runningServersOrder []string
socketPtrOffsetMap map[string]uint
runningServersForked bool
// DefaultReadTimeOut is the HTTP read timeout
DefaultReadTimeOut time.Duration
// DefaultWriteTimeOut is the HTTP Write timeout
DefaultWriteTimeOut time.Duration
// DefaultMaxHeaderBytes is the Max HTTP Header size, default is 0, no limit
DefaultMaxHeaderBytes int
// DefaultTimeout is the shutdown server's timeout. default is 60s
DefaultTimeout = 60 * time.Second
isChild bool
socketOrder string
hookableSignals []os.Signal
)
func init() {
flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
regLock = &sync.Mutex{}
runningServers = make(map[string]*Server)
runningServersOrder = []string{}
socketPtrOffsetMap = make(map[string]uint)
hookableSignals = []os.Signal{
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
}
}
// NewServer returns a new graceServer.
func NewServer(addr string, handler http.Handler) (srv *Server) {
regLock.Lock()
defer regLock.Unlock()
if !flag.Parsed() {
flag.Parse()
}
if len(socketOrder) > 0 {
for i, addr := range strings.Split(socketOrder, ",") {
socketPtrOffsetMap[addr] = uint(i)
}
} else {
socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
}
srv = &Server{
sigChan: make(chan os.Signal),
isChild: isChild,
SignalHooks: map[int]map[os.Signal][]func(){
PreSignal: {
syscall.SIGHUP: {},
syscall.SIGINT: {},
syscall.SIGTERM: {},
},
PostSignal: {
syscall.SIGHUP: {},
syscall.SIGINT: {},
syscall.SIGTERM: {},
},
},
state: StateInit,
Network: "tcp",
terminalChan: make(chan error), //no cache channel
}
srv.Server = &http.Server{
Addr: addr,
ReadTimeout: DefaultReadTimeOut,
WriteTimeout: DefaultWriteTimeOut,
MaxHeaderBytes: DefaultMaxHeaderBytes,
Handler: handler,
}
runningServersOrder = append(runningServersOrder, addr)
runningServers[addr] = srv
return srv
}
// ListenAndServe refer http.ListenAndServe
func ListenAndServe(addr string, handler http.Handler) error {
server := NewServer(addr, handler)
return server.ListenAndServe()
}
// ListenAndServeTLS refer http.ListenAndServeTLS
func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
server := NewServer(addr, handler)
return server.ListenAndServeTLS(certFile, keyFile)
}

356
server/web/grace/server.go Normal file
View File

@ -0,0 +1,356 @@
package grace
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"
"time"
)
// Server embedded http.Server
type Server struct {
*http.Server
ln net.Listener
SignalHooks map[int]map[os.Signal][]func()
sigChan chan os.Signal
isChild bool
state uint8
Network string
terminalChan chan error
}
// Serve accepts incoming connections on the Listener l
// and creates a new service goroutine for each.
// The service goroutines read requests and then call srv.Handler to reply to them.
func (srv *Server) Serve() (err error) {
srv.state = StateRunning
defer func() { srv.state = StateTerminate }()
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
// immediately return ErrServerClosed. Make sure the program doesn't exit
// and waits instead for Shutdown to return.
if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed {
log.Println(syscall.Getpid(), "Server.Serve() error:", err)
return err
}
log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
// wait for Shutdown to return
if shutdownErr := <-srv.terminalChan; shutdownErr != nil {
return shutdownErr
}
return
}
// ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
// to handle requests on incoming connections. If srv.Addr is blank, ":http" is
// used.
func (srv *Server) ListenAndServe() (err error) {
addr := srv.Addr
if addr == "" {
addr = ":http"
}
go srv.handleSignals()
srv.ln, err = srv.getListener(addr)
if err != nil {
log.Println(err)
return err
}
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Signal(syscall.SIGTERM)
if err != nil {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
}
// ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming TLS connections.
//
// Filenames containing a certificate and matching private key for the server must
// be provided. If the certificate is signed by a certificate authority, the
// certFile should be the concatenation of the server's certificate followed by the
// CA's certificate.
//
// If srv.Addr is blank, ":https" is used.
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
addr := srv.Addr
if addr == "" {
addr = ":https"
}
if srv.TLSConfig == nil {
srv.TLSConfig = &tls.Config{}
}
if srv.TLSConfig.NextProtos == nil {
srv.TLSConfig.NextProtos = []string{"http/1.1"}
}
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return
}
go srv.handleSignals()
ln, err := srv.getListener(addr)
if err != nil {
log.Println(err)
return err
}
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Signal(syscall.SIGTERM)
if err != nil {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
}
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming mutual TLS connections.
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
addr := srv.Addr
if addr == "" {
addr = ":https"
}
if srv.TLSConfig == nil {
srv.TLSConfig = &tls.Config{}
}
if srv.TLSConfig.NextProtos == nil {
srv.TLSConfig.NextProtos = []string{"http/1.1"}
}
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return
}
srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
pool := x509.NewCertPool()
data, err := ioutil.ReadFile(trustFile)
if err != nil {
log.Println(err)
return err
}
pool.AppendCertsFromPEM(data)
srv.TLSConfig.ClientCAs = pool
log.Println("Mutual HTTPS")
go srv.handleSignals()
ln, err := srv.getListener(addr)
if err != nil {
log.Println(err)
return err
}
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Signal(syscall.SIGTERM)
if err != nil {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
}
// getListener either opens a new socket to listen on, or takes the acceptor socket
// it got passed when restarted.
func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
if srv.isChild {
var ptrOffset uint
if len(socketPtrOffsetMap) > 0 {
ptrOffset = socketPtrOffsetMap[laddr]
log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
}
f := os.NewFile(uintptr(3+ptrOffset), "")
l, err = net.FileListener(f)
if err != nil {
err = fmt.Errorf("net.FileListener error: %v", err)
return
}
} else {
l, err = net.Listen(srv.Network, laddr)
if err != nil {
err = fmt.Errorf("net.Listen error: %v", err)
return
}
}
return
}
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
// handleSignals listens for os Signals and calls any hooked in function that the
// user had registered with the signal.
func (srv *Server) handleSignals() {
var sig os.Signal
signal.Notify(
srv.sigChan,
hookableSignals...,
)
pid := syscall.Getpid()
for {
sig = <-srv.sigChan
srv.signalHooks(PreSignal, sig)
switch sig {
case syscall.SIGHUP:
log.Println(pid, "Received SIGHUP. forking.")
err := srv.fork()
if err != nil {
log.Println("Fork err:", err)
}
case syscall.SIGINT:
log.Println(pid, "Received SIGINT.")
srv.shutdown()
case syscall.SIGTERM:
log.Println(pid, "Received SIGTERM.")
srv.shutdown()
default:
log.Printf("Received %v: nothing i care about...\n", sig)
}
srv.signalHooks(PostSignal, sig)
}
}
func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
return
}
for _, f := range srv.SignalHooks[ppFlag][sig] {
f()
}
}
// shutdown closes the listener so that no new connections are accepted. it also
// starts a goroutine that will serverTimeout (stop all running requests) the server
// after DefaultTimeout.
func (srv *Server) shutdown() {
if srv.state != StateRunning {
return
}
srv.state = StateShuttingDown
log.Println(syscall.Getpid(), "Waiting for connections to finish...")
ctx := context.Background()
if DefaultTimeout >= 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
}
srv.terminalChan <- srv.Server.Shutdown(ctx)
}
func (srv *Server) fork() (err error) {
regLock.Lock()
defer regLock.Unlock()
if runningServersForked {
return
}
runningServersForked = true
var files = make([]*os.File, len(runningServers))
var orderArgs = make([]string, len(runningServers))
for _, srvPtr := range runningServers {
f, _ := srvPtr.ln.(*net.TCPListener).File()
files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f
orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
}
log.Println(files)
path := os.Args[0]
var args []string
if len(os.Args) > 1 {
for _, arg := range os.Args[1:] {
if arg == "-graceful" {
break
}
args = append(args, arg)
}
}
args = append(args, "-graceful")
if len(runningServers) > 1 {
args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
log.Println(args)
}
cmd := exec.Command(path, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.ExtraFiles = files
err = cmd.Start()
if err != nil {
log.Fatalf("Restart: Failed to launch, error: %v", err)
}
return
}
// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal.
func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
if ppFlag != PreSignal && ppFlag != PostSignal {
err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal")
return
}
for _, s := range hookableSignals {
if s == sig {
srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
return
}
}
err = fmt.Errorf("Signal '%v' is not supported", sig)
return
}

108
server/web/hooks.go Normal file
View File

@ -0,0 +1,108 @@
package web
import (
context2 "context"
"encoding/json"
"mime"
"net/http"
"path/filepath"
"github.com/astaxie/beego/core/logs"
"github.com/astaxie/beego/server/web/context"
"github.com/astaxie/beego/server/web/session"
)
// register MIME type with content type
func registerMime() error {
for k, v := range mimemaps {
mime.AddExtensionType(k, v)
}
return nil
}
// register default error http handlers, 404,401,403,500 and 503.
func registerDefaultErrorHandler() error {
m := map[string]func(http.ResponseWriter, *http.Request){
"401": unauthorized,
"402": paymentRequired,
"403": forbidden,
"404": notFound,
"405": methodNotAllowed,
"500": internalServerError,
"501": notImplemented,
"502": badGateway,
"503": serviceUnavailable,
"504": gatewayTimeout,
"417": invalidxsrf,
"422": missingxsrf,
"413": payloadTooLarge,
}
for e, h := range m {
if _, ok := ErrorMaps[e]; !ok {
ErrorHandler(e, h)
}
}
return nil
}
func registerSession() error {
if BConfig.WebConfig.Session.SessionOn {
var err error
sessionConfig, err := AppConfig.String(nil, "sessionConfig")
conf := new(session.ManagerConfig)
if sessionConfig == "" || err != nil {
conf.CookieName = BConfig.WebConfig.Session.SessionName
conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie
conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime
conf.Secure = BConfig.Listen.EnableHTTPS
conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime
conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig)
conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly
conf.Domain = BConfig.WebConfig.Session.SessionDomain
conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader
conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader
conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery
} else {
if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil {
return err
}
}
if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, conf); err != nil {
return err
}
go GlobalSessions.GC()
}
return nil
}
func registerTemplate() error {
defer lockViewPaths()
if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil {
if BConfig.RunMode == DEV {
logs.Warn(err)
}
return err
}
return nil
}
func registerGzip() error {
if BConfig.EnableGzip {
context.InitGzip(
AppConfig.DefaultInt(context2.Background(), "gzipMinLength", -1),
AppConfig.DefaultInt(context2.Background(), "gzipCompressLevel", -1),
AppConfig.DefaultStrings(context2.Background(), "includedMethods", []string{"GET"}),
)
}
return nil
}
func registerCommentRouter() error {
if BConfig.RunMode == DEV {
if err := parserPkg(filepath.Join(WorkPath, BConfig.WebConfig.CommentRouterPath)); err != nil {
return err
}
}
return nil
}

556
server/web/mime.go Normal file
View File

@ -0,0 +1,556 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
var mimemaps = map[string]string{
".3dm": "x-world/x-3dmf",
".3dmf": "x-world/x-3dmf",
".7z": "application/x-7z-compressed",
".a": "application/octet-stream",
".aab": "application/x-authorware-bin",
".aam": "application/x-authorware-map",
".aas": "application/x-authorware-seg",
".abc": "text/vndabc",
".ace": "application/x-ace-compressed",
".acgi": "text/html",
".afl": "video/animaflex",
".ai": "application/postscript",
".aif": "audio/aiff",
".aifc": "audio/aiff",
".aiff": "audio/aiff",
".aim": "application/x-aim",
".aip": "text/x-audiosoft-intra",
".alz": "application/x-alz-compressed",
".ani": "application/x-navi-animation",
".aos": "application/x-nokia-9000-communicator-add-on-software",
".aps": "application/mime",
".apk": "application/vnd.android.package-archive",
".arc": "application/x-arc-compressed",
".arj": "application/arj",
".art": "image/x-jg",
".asf": "video/x-ms-asf",
".asm": "text/x-asm",
".asp": "text/asp",
".asx": "application/x-mplayer2",
".au": "audio/basic",
".avi": "video/x-msvideo",
".avs": "video/avs-video",
".bcpio": "application/x-bcpio",
".bin": "application/mac-binary",
".bmp": "image/bmp",
".boo": "application/book",
".book": "application/book",
".boz": "application/x-bzip2",
".bsh": "application/x-bsh",
".bz2": "application/x-bzip2",
".bz": "application/x-bzip",
".c++": "text/plain",
".c": "text/x-c",
".cab": "application/vnd.ms-cab-compressed",
".cat": "application/vndms-pkiseccat",
".cc": "text/x-c",
".ccad": "application/clariscad",
".cco": "application/x-cocoa",
".cdf": "application/cdf",
".cer": "application/pkix-cert",
".cha": "application/x-chat",
".chat": "application/x-chat",
".chrt": "application/vnd.kde.kchart",
".class": "application/java",
".com": "text/plain",
".conf": "text/plain",
".cpio": "application/x-cpio",
".cpp": "text/x-c",
".cpt": "application/mac-compactpro",
".crl": "application/pkcs-crl",
".crt": "application/pkix-cert",
".crx": "application/x-chrome-extension",
".csh": "text/x-scriptcsh",
".css": "text/css",
".csv": "text/csv",
".cxx": "text/plain",
".dar": "application/x-dar",
".dcr": "application/x-director",
".deb": "application/x-debian-package",
".deepv": "application/x-deepv",
".def": "text/plain",
".der": "application/x-x509-ca-cert",
".dif": "video/x-dv",
".dir": "application/x-director",
".divx": "video/divx",
".dl": "video/dl",
".dmg": "application/x-apple-diskimage",
".doc": "application/msword",
".dot": "application/msword",
".dp": "application/commonground",
".drw": "application/drafting",
".dump": "application/octet-stream",
".dv": "video/x-dv",
".dvi": "application/x-dvi",
".dwf": "drawing/x-dwf=(old)",
".dwg": "application/acad",
".dxf": "application/dxf",
".dxr": "application/x-director",
".el": "text/x-scriptelisp",
".elc": "application/x-bytecodeelisp=(compiled=elisp)",
".eml": "message/rfc822",
".env": "application/x-envoy",
".eps": "application/postscript",
".es": "application/x-esrehber",
".etx": "text/x-setext",
".evy": "application/envoy",
".exe": "application/octet-stream",
".f77": "text/x-fortran",
".f90": "text/x-fortran",
".f": "text/x-fortran",
".fdf": "application/vndfdf",
".fif": "application/fractals",
".fli": "video/fli",
".flo": "image/florian",
".flv": "video/x-flv",
".flx": "text/vndfmiflexstor",
".fmf": "video/x-atomic3d-feature",
".for": "text/x-fortran",
".fpx": "image/vndfpx",
".frl": "application/freeloader",
".funk": "audio/make",
".g3": "image/g3fax",
".g": "text/plain",
".gif": "image/gif",
".gl": "video/gl",
".gsd": "audio/x-gsm",
".gsm": "audio/x-gsm",
".gsp": "application/x-gsp",
".gss": "application/x-gss",
".gtar": "application/x-gtar",
".gz": "application/x-compressed",
".gzip": "application/x-gzip",
".h": "text/x-h",
".hdf": "application/x-hdf",
".help": "application/x-helpfile",
".hgl": "application/vndhp-hpgl",
".hh": "text/x-h",
".hlb": "text/x-script",
".hlp": "application/hlp",
".hpg": "application/vndhp-hpgl",
".hpgl": "application/vndhp-hpgl",
".hqx": "application/binhex",
".hta": "application/hta",
".htc": "text/x-component",
".htm": "text/html",
".html": "text/html",
".htmls": "text/html",
".htt": "text/webviewhtml",
".htx": "text/html",
".ice": "x-conference/x-cooltalk",
".ico": "image/x-icon",
".ics": "text/calendar",
".icz": "text/calendar",
".idc": "text/plain",
".ief": "image/ief",
".iefs": "image/ief",
".iges": "application/iges",
".igs": "application/iges",
".ima": "application/x-ima",
".imap": "application/x-httpd-imap",
".inf": "application/inf",
".ins": "application/x-internett-signup",
".ip": "application/x-ip2",
".isu": "video/x-isvideo",
".it": "audio/it",
".iv": "application/x-inventor",
".ivr": "i-world/i-vrml",
".ivy": "application/x-livescreen",
".jam": "audio/x-jam",
".jav": "text/x-java-source",
".java": "text/x-java-source",
".jcm": "application/x-java-commerce",
".jfif-tbnl": "image/jpeg",
".jfif": "image/jpeg",
".jnlp": "application/x-java-jnlp-file",
".jpe": "image/jpeg",
".jpeg": "image/jpeg",
".jpg": "image/jpeg",
".jps": "image/x-jps",
".js": "application/javascript",
".json": "application/json",
".jut": "image/jutvision",
".kar": "audio/midi",
".karbon": "application/vnd.kde.karbon",
".kfo": "application/vnd.kde.kformula",
".flw": "application/vnd.kde.kivio",
".kml": "application/vnd.google-earth.kml+xml",
".kmz": "application/vnd.google-earth.kmz",
".kon": "application/vnd.kde.kontour",
".kpr": "application/vnd.kde.kpresenter",
".kpt": "application/vnd.kde.kpresenter",
".ksp": "application/vnd.kde.kspread",
".kwd": "application/vnd.kde.kword",
".kwt": "application/vnd.kde.kword",
".ksh": "text/x-scriptksh",
".la": "audio/nspaudio",
".lam": "audio/x-liveaudio",
".latex": "application/x-latex",
".lha": "application/lha",
".lhx": "application/octet-stream",
".list": "text/plain",
".lma": "audio/nspaudio",
".log": "text/plain",
".lsp": "text/x-scriptlisp",
".lst": "text/plain",
".lsx": "text/x-la-asf",
".ltx": "application/x-latex",
".lzh": "application/octet-stream",
".lzx": "application/lzx",
".m1v": "video/mpeg",
".m2a": "audio/mpeg",
".m2v": "video/mpeg",
".m3u": "audio/x-mpegurl",
".m": "text/x-m",
".man": "application/x-troff-man",
".manifest": "text/cache-manifest",
".map": "application/x-navimap",
".mar": "text/plain",
".mbd": "application/mbedlet",
".mc$": "application/x-magic-cap-package-10",
".mcd": "application/mcad",
".mcf": "text/mcf",
".mcp": "application/netmc",
".me": "application/x-troff-me",
".mht": "message/rfc822",
".mhtml": "message/rfc822",
".mid": "application/x-midi",
".midi": "application/x-midi",
".mif": "application/x-frame",
".mime": "message/rfc822",
".mjf": "audio/x-vndaudioexplosionmjuicemediafile",
".mjpg": "video/x-motion-jpeg",
".mm": "application/base64",
".mme": "application/base64",
".mod": "audio/mod",
".moov": "video/quicktime",
".mov": "video/quicktime",
".movie": "video/x-sgi-movie",
".mp2": "audio/mpeg",
".mp3": "audio/mpeg3",
".mp4": "video/mp4",
".mpa": "audio/mpeg",
".mpc": "application/x-project",
".mpe": "video/mpeg",
".mpeg": "video/mpeg",
".mpg": "video/mpeg",
".mpga": "audio/mpeg",
".mpp": "application/vndms-project",
".mpt": "application/x-project",
".mpv": "application/x-project",
".mpx": "application/x-project",
".mrc": "application/marc",
".ms": "application/x-troff-ms",
".mv": "video/x-sgi-movie",
".my": "audio/make",
".mzz": "application/x-vndaudioexplosionmzz",
".nap": "image/naplps",
".naplps": "image/naplps",
".nc": "application/x-netcdf",
".ncm": "application/vndnokiaconfiguration-message",
".nif": "image/x-niff",
".niff": "image/x-niff",
".nix": "application/x-mix-transfer",
".nsc": "application/x-conference",
".nvd": "application/x-navidoc",
".o": "application/octet-stream",
".oda": "application/oda",
".odb": "application/vnd.oasis.opendocument.database",
".odc": "application/vnd.oasis.opendocument.chart",
".odf": "application/vnd.oasis.opendocument.formula",
".odg": "application/vnd.oasis.opendocument.graphics",
".odi": "application/vnd.oasis.opendocument.image",
".odm": "application/vnd.oasis.opendocument.text-master",
".odp": "application/vnd.oasis.opendocument.presentation",
".ods": "application/vnd.oasis.opendocument.spreadsheet",
".odt": "application/vnd.oasis.opendocument.text",
".oga": "audio/ogg",
".ogg": "audio/ogg",
".ogv": "video/ogg",
".omc": "application/x-omc",
".omcd": "application/x-omcdatamaker",
".omcr": "application/x-omcregerator",
".otc": "application/vnd.oasis.opendocument.chart-template",
".otf": "application/vnd.oasis.opendocument.formula-template",
".otg": "application/vnd.oasis.opendocument.graphics-template",
".oth": "application/vnd.oasis.opendocument.text-web",
".oti": "application/vnd.oasis.opendocument.image-template",
".otm": "application/vnd.oasis.opendocument.text-master",
".otp": "application/vnd.oasis.opendocument.presentation-template",
".ots": "application/vnd.oasis.opendocument.spreadsheet-template",
".ott": "application/vnd.oasis.opendocument.text-template",
".p10": "application/pkcs10",
".p12": "application/pkcs-12",
".p7a": "application/x-pkcs7-signature",
".p7c": "application/pkcs7-mime",
".p7m": "application/pkcs7-mime",
".p7r": "application/x-pkcs7-certreqresp",
".p7s": "application/pkcs7-signature",
".p": "text/x-pascal",
".part": "application/pro_eng",
".pas": "text/pascal",
".pbm": "image/x-portable-bitmap",
".pcl": "application/vndhp-pcl",
".pct": "image/x-pict",
".pcx": "image/x-pcx",
".pdb": "chemical/x-pdb",
".pdf": "application/pdf",
".pfunk": "audio/make",
".pgm": "image/x-portable-graymap",
".pic": "image/pict",
".pict": "image/pict",
".pkg": "application/x-newton-compatible-pkg",
".pko": "application/vndms-pkipko",
".pl": "text/x-scriptperl",
".plx": "application/x-pixclscript",
".pm4": "application/x-pagemaker",
".pm5": "application/x-pagemaker",
".pm": "text/x-scriptperl-module",
".png": "image/png",
".pnm": "application/x-portable-anymap",
".pot": "application/mspowerpoint",
".pov": "model/x-pov",
".ppa": "application/vndms-powerpoint",
".ppm": "image/x-portable-pixmap",
".pps": "application/mspowerpoint",
".ppt": "application/mspowerpoint",
".ppz": "application/mspowerpoint",
".pre": "application/x-freelance",
".prt": "application/pro_eng",
".ps": "application/postscript",
".psd": "application/octet-stream",
".pvu": "paleovu/x-pv",
".pwz": "application/vndms-powerpoint",
".py": "text/x-scriptphyton",
".pyc": "application/x-bytecodepython",
".qcp": "audio/vndqcelp",
".qd3": "x-world/x-3dmf",
".qd3d": "x-world/x-3dmf",
".qif": "image/x-quicktime",
".qt": "video/quicktime",
".qtc": "video/x-qtc",
".qti": "image/x-quicktime",
".qtif": "image/x-quicktime",
".ra": "audio/x-pn-realaudio",
".ram": "audio/x-pn-realaudio",
".rar": "application/x-rar-compressed",
".ras": "application/x-cmu-raster",
".rast": "image/cmu-raster",
".rexx": "text/x-scriptrexx",
".rf": "image/vndrn-realflash",
".rgb": "image/x-rgb",
".rm": "application/vndrn-realmedia",
".rmi": "audio/mid",
".rmm": "audio/x-pn-realaudio",
".rmp": "audio/x-pn-realaudio",
".rng": "application/ringing-tones",
".rnx": "application/vndrn-realplayer",
".roff": "application/x-troff",
".rp": "image/vndrn-realpix",
".rpm": "audio/x-pn-realaudio-plugin",
".rt": "text/vndrn-realtext",
".rtf": "text/richtext",
".rtx": "text/richtext",
".rv": "video/vndrn-realvideo",
".s": "text/x-asm",
".s3m": "audio/s3m",
".s7z": "application/x-7z-compressed",
".saveme": "application/octet-stream",
".sbk": "application/x-tbook",
".scm": "text/x-scriptscheme",
".sdml": "text/plain",
".sdp": "application/sdp",
".sdr": "application/sounder",
".sea": "application/sea",
".set": "application/set",
".sgm": "text/x-sgml",
".sgml": "text/x-sgml",
".sh": "text/x-scriptsh",
".shar": "application/x-bsh",
".shtml": "text/x-server-parsed-html",
".sid": "audio/x-psid",
".skd": "application/x-koan",
".skm": "application/x-koan",
".skp": "application/x-koan",
".skt": "application/x-koan",
".sit": "application/x-stuffit",
".sitx": "application/x-stuffitx",
".sl": "application/x-seelogo",
".smi": "application/smil",
".smil": "application/smil",
".snd": "audio/basic",
".sol": "application/solids",
".spc": "text/x-speech",
".spl": "application/futuresplash",
".spr": "application/x-sprite",
".sprite": "application/x-sprite",
".spx": "audio/ogg",
".src": "application/x-wais-source",
".ssi": "text/x-server-parsed-html",
".ssm": "application/streamingmedia",
".sst": "application/vndms-pkicertstore",
".step": "application/step",
".stl": "application/sla",
".stp": "application/step",
".sv4cpio": "application/x-sv4cpio",
".sv4crc": "application/x-sv4crc",
".svf": "image/vnddwg",
".svg": "image/svg+xml",
".svr": "application/x-world",
".swf": "application/x-shockwave-flash",
".t": "application/x-troff",
".talk": "text/x-speech",
".tar": "application/x-tar",
".tbk": "application/toolbook",
".tcl": "text/x-scripttcl",
".tcsh": "text/x-scripttcsh",
".tex": "application/x-tex",
".texi": "application/x-texinfo",
".texinfo": "application/x-texinfo",
".text": "text/plain",
".tgz": "application/gnutar",
".tif": "image/tiff",
".tiff": "image/tiff",
".tr": "application/x-troff",
".tsi": "audio/tsp-audio",
".tsp": "application/dsptype",
".tsv": "text/tab-separated-values",
".turbot": "image/florian",
".txt": "text/plain",
".uil": "text/x-uil",
".uni": "text/uri-list",
".unis": "text/uri-list",
".unv": "application/i-deas",
".uri": "text/uri-list",
".uris": "text/uri-list",
".ustar": "application/x-ustar",
".uu": "text/x-uuencode",
".uue": "text/x-uuencode",
".vcd": "application/x-cdlink",
".vcf": "text/x-vcard",
".vcard": "text/x-vcard",
".vcs": "text/x-vcalendar",
".vda": "application/vda",
".vdo": "video/vdo",
".vew": "application/groupwise",
".viv": "video/vivo",
".vivo": "video/vivo",
".vmd": "application/vocaltec-media-desc",
".vmf": "application/vocaltec-media-file",
".voc": "audio/voc",
".vos": "video/vosaic",
".vox": "audio/voxware",
".vqe": "audio/x-twinvq-plugin",
".vqf": "audio/x-twinvq",
".vql": "audio/x-twinvq-plugin",
".vrml": "application/x-vrml",
".vrt": "x-world/x-vrt",
".vsd": "application/x-visio",
".vst": "application/x-visio",
".vsw": "application/x-visio",
".w60": "application/wordperfect60",
".w61": "application/wordperfect61",
".w6w": "application/msword",
".wav": "audio/wav",
".wb1": "application/x-qpro",
".wbmp": "image/vnd.wap.wbmp",
".web": "application/vndxara",
".wiz": "application/msword",
".wk1": "application/x-123",
".wmf": "windows/metafile",
".wml": "text/vnd.wap.wml",
".wmlc": "application/vnd.wap.wmlc",
".wmls": "text/vnd.wap.wmlscript",
".wmlsc": "application/vnd.wap.wmlscriptc",
".word": "application/msword",
".wp5": "application/wordperfect",
".wp6": "application/wordperfect",
".wp": "application/wordperfect",
".wpd": "application/wordperfect",
".wq1": "application/x-lotus",
".wri": "application/mswrite",
".wrl": "application/x-world",
".wrz": "model/vrml",
".wsc": "text/scriplet",
".wsrc": "application/x-wais-source",
".wtk": "application/x-wintalk",
".x-png": "image/png",
".xbm": "image/x-xbitmap",
".xdr": "video/x-amt-demorun",
".xgz": "xgl/drawing",
".xif": "image/vndxiff",
".xl": "application/excel",
".xla": "application/excel",
".xlb": "application/excel",
".xlc": "application/excel",
".xld": "application/excel",
".xlk": "application/excel",
".xll": "application/excel",
".xlm": "application/excel",
".xls": "application/excel",
".xlt": "application/excel",
".xlv": "application/excel",
".xlw": "application/excel",
".xm": "audio/xm",
".xml": "text/xml",
".xmz": "xgl/movie",
".xpix": "application/x-vndls-xpix",
".xpm": "image/x-xpixmap",
".xsr": "video/x-amt-showrun",
".xwd": "image/x-xwd",
".xyz": "chemical/x-pdb",
".z": "application/x-compress",
".zip": "application/zip",
".zoo": "application/octet-stream",
".zsh": "text/x-scriptzsh",
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".docm": "application/vnd.ms-word.document.macroEnabled.12",
".dotx": "application/vnd.openxmlformats-officedocument.wordprocessingml.template",
".dotm": "application/vnd.ms-word.template.macroEnabled.12",
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
".xlsm": "application/vnd.ms-excel.sheet.macroEnabled.12",
".xltx": "application/vnd.openxmlformats-officedocument.spreadsheetml.template",
".xltm": "application/vnd.ms-excel.template.macroEnabled.12",
".xlsb": "application/vnd.ms-excel.sheet.binary.macroEnabled.12",
".xlam": "application/vnd.ms-excel.addin.macroEnabled.12",
".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
".pptm": "application/vnd.ms-powerpoint.presentation.macroEnabled.12",
".ppsx": "application/vnd.openxmlformats-officedocument.presentationml.slideshow",
".ppsm": "application/vnd.ms-powerpoint.slideshow.macroEnabled.12",
".potx": "application/vnd.openxmlformats-officedocument.presentationml.template",
".potm": "application/vnd.ms-powerpoint.template.macroEnabled.12",
".ppam": "application/vnd.ms-powerpoint.addin.macroEnabled.12",
".sldx": "application/vnd.openxmlformats-officedocument.presentationml.slide",
".sldm": "application/vnd.ms-powerpoint.slide.macroEnabled.12",
".thmx": "application/vnd.ms-officetheme",
".onetoc": "application/onenote",
".onetoc2": "application/onenote",
".onetmp": "application/onenote",
".onepkg": "application/onenote",
".key": "application/x-iwork-keynote-sffkey",
".kth": "application/x-iwork-keynote-sffkth",
".nmbtemplate": "application/x-iwork-numbers-sfftemplate",
".numbers": "application/x-iwork-numbers-sffnumbers",
".pages": "application/x-iwork-pages-sffpages",
".template": "application/x-iwork-pages-sfftemplate",
".xpi": "application/x-xpinstall",
".oex": "application/x-opera-extension",
".mustache": "text/html",
}

396
server/web/namespace.go Normal file
View File

@ -0,0 +1,396 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"net/http"
"strings"
beecontext "github.com/astaxie/beego/server/web/context"
)
type namespaceCond func(*beecontext.Context) bool
// LinkNamespace used as link action
type LinkNamespace func(*Namespace)
// Namespace is store all the info
type Namespace struct {
prefix string
handlers *ControllerRegister
}
// NewNamespace get new Namespace
func NewNamespace(prefix string, params ...LinkNamespace) *Namespace {
ns := &Namespace{
prefix: prefix,
handlers: NewControllerRegister(),
}
for _, p := range params {
p(ns)
}
return ns
}
// Cond set condition function
// if cond return true can run this namespace, else can't
// usage:
// ns.Cond(func (ctx *context.Context) bool{
// if ctx.Input.Domain() == "api.beego.me" {
// return true
// }
// return false
// })
// Cond as the first filter
func (n *Namespace) Cond(cond namespaceCond) *Namespace {
fn := func(ctx *beecontext.Context) {
if !cond(ctx) {
exception("405", ctx)
}
}
if v := n.handlers.filters[BeforeRouter]; len(v) > 0 {
mr := new(FilterRouter)
mr.tree = NewTree()
mr.pattern = "*"
mr.filterFunc = fn
mr.tree.AddRouter("*", true)
n.handlers.filters[BeforeRouter] = append([]*FilterRouter{mr}, v...)
} else {
n.handlers.InsertFilter("*", BeforeRouter, fn)
}
return n
}
// Filter add filter in the Namespace
// action has before & after
// FilterFunc
// usage:
// Filter("before", func (ctx *context.Context){
// _, ok := ctx.Input.Session("uid").(int)
// if !ok && ctx.Request.RequestURI != "/login" {
// ctx.Redirect(302, "/login")
// }
// })
func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
var a int
if action == "before" {
a = BeforeRouter
} else if action == "after" {
a = FinishRouter
}
for _, f := range filter {
n.handlers.InsertFilter("*", a, f, WithReturnOnOutput(true))
}
return n
}
// Router same as beego.Rourer
// refer: https://godoc.org/github.com/astaxie/beego#Router
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
n.handlers.Add(rootpath, c, mappingMethods...)
return n
}
// AutoRouter same as beego.AutoRouter
// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter
func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace {
n.handlers.AddAuto(c)
return n
}
// AutoPrefix same as beego.AutoPrefix
// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix
func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace {
n.handlers.AddAutoPrefix(prefix, c)
return n
}
// Get same as beego.Get
// refer: https://godoc.org/github.com/astaxie/beego#Get
func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace {
n.handlers.Get(rootpath, f)
return n
}
// Post same as beego.Post
// refer: https://godoc.org/github.com/astaxie/beego#Post
func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace {
n.handlers.Post(rootpath, f)
return n
}
// Delete same as beego.Delete
// refer: https://godoc.org/github.com/astaxie/beego#Delete
func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace {
n.handlers.Delete(rootpath, f)
return n
}
// Put same as beego.Put
// refer: https://godoc.org/github.com/astaxie/beego#Put
func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace {
n.handlers.Put(rootpath, f)
return n
}
// Head same as beego.Head
// refer: https://godoc.org/github.com/astaxie/beego#Head
func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace {
n.handlers.Head(rootpath, f)
return n
}
// Options same as beego.Options
// refer: https://godoc.org/github.com/astaxie/beego#Options
func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace {
n.handlers.Options(rootpath, f)
return n
}
// Patch same as beego.Patch
// refer: https://godoc.org/github.com/astaxie/beego#Patch
func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace {
n.handlers.Patch(rootpath, f)
return n
}
// Any same as beego.Any
// refer: https://godoc.org/github.com/astaxie/beego#Any
func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace {
n.handlers.Any(rootpath, f)
return n
}
// Handler same as beego.Handler
// refer: https://godoc.org/github.com/astaxie/beego#Handler
func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace {
n.handlers.Handler(rootpath, h)
return n
}
// Include add include class
// refer: https://godoc.org/github.com/astaxie/beego#Include
func (n *Namespace) Include(cList ...ControllerInterface) *Namespace {
n.handlers.Include(cList...)
return n
}
// Namespace add nest Namespace
// usage:
//ns := beego.NewNamespace(“/v1”).
//Namespace(
// beego.NewNamespace("/shop").
// Get("/:id", func(ctx *context.Context) {
// ctx.Output.Body([]byte("shopinfo"))
// }),
// beego.NewNamespace("/order").
// Get("/:id", func(ctx *context.Context) {
// ctx.Output.Body([]byte("orderinfo"))
// }),
// beego.NewNamespace("/crm").
// Get("/:id", func(ctx *context.Context) {
// ctx.Output.Body([]byte("crminfo"))
// }),
//)
func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
for _, ni := range ns {
for k, v := range ni.handlers.routers {
if _, ok := n.handlers.routers[k]; ok {
addPrefix(v, ni.prefix)
n.handlers.routers[k].AddTree(ni.prefix, v)
} else {
t := NewTree()
t.AddTree(ni.prefix, v)
addPrefix(t, ni.prefix)
n.handlers.routers[k] = t
}
}
if ni.handlers.enableFilter {
for pos, filterList := range ni.handlers.filters {
for _, mr := range filterList {
t := NewTree()
t.AddTree(ni.prefix, mr.tree)
mr.tree = t
n.handlers.insertFilterRouter(pos, mr)
}
}
}
}
return n
}
// AddNamespace register Namespace into beego.Handler
// support multi Namespace
func AddNamespace(nl ...*Namespace) {
for _, n := range nl {
for k, v := range n.handlers.routers {
if _, ok := BeeApp.Handlers.routers[k]; ok {
addPrefix(v, n.prefix)
BeeApp.Handlers.routers[k].AddTree(n.prefix, v)
} else {
t := NewTree()
t.AddTree(n.prefix, v)
addPrefix(t, n.prefix)
BeeApp.Handlers.routers[k] = t
}
}
if n.handlers.enableFilter {
for pos, filterList := range n.handlers.filters {
for _, mr := range filterList {
t := NewTree()
t.AddTree(n.prefix, mr.tree)
mr.tree = t
BeeApp.Handlers.insertFilterRouter(pos, mr)
}
}
}
}
}
func addPrefix(t *Tree, prefix string) {
for _, v := range t.fixrouters {
addPrefix(v, prefix)
}
if t.wildcard != nil {
addPrefix(t.wildcard, prefix)
}
for _, l := range t.leaves {
if c, ok := l.runObject.(*ControllerInfo); ok {
if !strings.HasPrefix(c.pattern, prefix) {
c.pattern = prefix + c.pattern
}
}
}
}
// NSCond is Namespace Condition
func NSCond(cond namespaceCond) LinkNamespace {
return func(ns *Namespace) {
ns.Cond(cond)
}
}
// NSBefore Namespace BeforeRouter filter
func NSBefore(filterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Filter("before", filterList...)
}
}
// NSAfter add Namespace FinishRouter filter
func NSAfter(filterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Filter("after", filterList...)
}
}
// NSInclude Namespace Include ControllerInterface
func NSInclude(cList ...ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.Include(cList...)
}
}
// NSRouter call Namespace Router
func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace {
return func(ns *Namespace) {
ns.Router(rootpath, c, mappingMethods...)
}
}
// NSGet call Namespace Get
func NSGet(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Get(rootpath, f)
}
}
// NSPost call Namespace Post
func NSPost(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Post(rootpath, f)
}
}
// NSHead call Namespace Head
func NSHead(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Head(rootpath, f)
}
}
// NSPut call Namespace Put
func NSPut(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Put(rootpath, f)
}
}
// NSDelete call Namespace Delete
func NSDelete(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Delete(rootpath, f)
}
}
// NSAny call Namespace Any
func NSAny(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Any(rootpath, f)
}
}
// NSOptions call Namespace Options
func NSOptions(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Options(rootpath, f)
}
}
// NSPatch call Namespace Patch
func NSPatch(rootpath string, f FilterFunc) LinkNamespace {
return func(ns *Namespace) {
ns.Patch(rootpath, f)
}
}
// NSAutoRouter call Namespace AutoRouter
func NSAutoRouter(c ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.AutoRouter(c)
}
}
// NSAutoPrefix call Namespace AutoPrefix
func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace {
return func(ns *Namespace) {
ns.AutoPrefix(prefix, c)
}
}
// NSNamespace add sub Namespace
func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace {
return func(ns *Namespace) {
n := NewNamespace(prefix, params...)
ns.Namespace(n)
}
}
// NSHandler add handler
func NSHandler(rootpath string, h http.Handler) LinkNamespace {
return func(ns *Namespace) {
ns.Handler(rootpath, h)
}
}

View File

@ -0,0 +1,168 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/astaxie/beego/server/web/context"
)
func TestNamespaceGet(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/user", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Get("/user", func(ctx *context.Context) {
ctx.Output.Body([]byte("v1_user"))
})
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "v1_user" {
t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String())
}
}
func TestNamespacePost(t *testing.T) {
r, _ := http.NewRequest("POST", "/v1/user/123", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Post("/user/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
})
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "123" {
t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNest(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/admin/order", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Namespace(
NewNamespace("/admin").
Get("/order", func(ctx *context.Context) {
ctx.Output.Body([]byte("order"))
}),
)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "order" {
t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceNestParam(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Namespace(
NewNamespace("/admin").
Get("/order/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
}),
)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "123" {
t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceRouter(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/api/list", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Router("/api/list", &TestController{}, "*:List")
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceAutoFunc(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/test/list", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.AutoRouter(&TestController{})
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("user define func can't run")
}
}
func TestNamespaceFilter(t *testing.T) {
r, _ := http.NewRequest("GET", "/v1/user/123", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v1")
ns.Filter("before", func(ctx *context.Context) {
ctx.Output.Body([]byte("this is Filter"))
}).
Get("/user/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
})
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "this is Filter" {
t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String())
}
}
func TestNamespaceCond(t *testing.T) {
r, _ := http.NewRequest("GET", "/v2/test/list", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v2")
ns.Cond(func(ctx *context.Context) bool {
return ctx.Input.Domain() == "beego.me"
}).
AutoRouter(&TestController{})
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Code != 405 {
t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code))
}
}
func TestNamespaceInside(t *testing.T) {
r, _ := http.NewRequest("GET", "/v3/shop/order/123", nil)
w := httptest.NewRecorder()
ns := NewNamespace("/v3",
NSAutoRouter(&TestController{}),
NSNamespace("/shop",
NSGet("/order/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
}),
),
)
AddNamespace(ns)
BeeApp.Handlers.ServeHTTP(w, r)
if w.Body.String() != "123" {
t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String())
}
}

View File

@ -0,0 +1,27 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pagination
import (
"github.com/astaxie/beego/core/utils/pagination"
"github.com/astaxie/beego/server/web/context"
)
// SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator").
func SetPaginator(context *context.Context, per int, nums int64) (paginator *pagination.Paginator) {
paginator = pagination.NewPaginator(context.Request, per, nums)
context.Input.SetData("paginator", &paginator)
return
}

601
server/web/parser.go Normal file
View File

@ -0,0 +1,601 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"context"
"encoding/json"
"errors"
"fmt"
"go/ast"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"unicode"
"golang.org/x/tools/go/packages"
"github.com/astaxie/beego/core/logs"
"github.com/astaxie/beego/core/utils"
"github.com/astaxie/beego/server/web/context/param"
)
var globalRouterTemplate = `package {{.routersDir}}
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/server/web/context/param"{{.globalimport}}
)
func init() {
{{.globalinfo}}
}
`
var (
lastupdateFilename = "lastupdate.tmp"
commentFilename string
pkgLastupdate map[string]int64
genInfoList map[string][]ControllerComments
routerHooks = map[string]int{
"beego.BeforeStatic": BeforeStatic,
"beego.BeforeRouter": BeforeRouter,
"beego.BeforeExec": BeforeExec,
"beego.AfterExec": AfterExec,
"beego.FinishRouter": FinishRouter,
}
routerHooksMapping = map[int]string{
BeforeStatic: "beego.BeforeStatic",
BeforeRouter: "beego.BeforeRouter",
BeforeExec: "beego.BeforeExec",
AfterExec: "beego.AfterExec",
FinishRouter: "beego.FinishRouter",
}
)
const commentPrefix = "commentsRouter_"
func init() {
pkgLastupdate = make(map[string]int64)
}
func parserPkg(pkgRealpath string) error {
rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_")
commentFilename, _ = filepath.Rel(AppPath, pkgRealpath)
commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go"
if !compareFile(pkgRealpath) {
logs.Info(pkgRealpath + " no changed")
return nil
}
genInfoList = make(map[string][]ControllerComments)
pkgs, err := packages.Load(&packages.Config{
Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedSyntax,
Dir: pkgRealpath,
}, "./...")
if err != nil {
return err
}
for _, pkg := range pkgs {
for _, fl := range pkg.Syntax {
for _, d := range fl.Decls {
switch specDecl := d.(type) {
case *ast.FuncDecl:
if specDecl.Recv != nil {
exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser
if ok {
parserComments(specDecl, fmt.Sprint(exp.X), pkg.PkgPath)
}
}
}
}
}
}
genRouterCode(pkgRealpath)
savetoFile(pkgRealpath)
return nil
}
type parsedComment struct {
routerPath string
methods []string
params map[string]parsedParam
filters []parsedFilter
imports []parsedImport
}
type parsedImport struct {
importPath string
importAlias string
}
type parsedFilter struct {
pattern string
pos int
filter string
params []bool
}
type parsedParam struct {
name string
datatype string
location string
defValue string
required bool
}
func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
if f.Doc != nil {
parsedComments, err := parseComment(f.Doc.List)
if err != nil {
return err
}
for _, parsedComment := range parsedComments {
if parsedComment.routerPath != "" {
key := pkgpath + ":" + controllerName
cc := ControllerComments{}
cc.Method = f.Name.String()
cc.Router = parsedComment.routerPath
cc.AllowHTTPMethods = parsedComment.methods
cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
cc.FilterComments = buildFilters(parsedComment.filters)
cc.ImportComments = buildImports(parsedComment.imports)
genInfoList[key] = append(genInfoList[key], cc)
}
}
}
return nil
}
func buildImports(pis []parsedImport) []*ControllerImportComments {
var importComments []*ControllerImportComments
for _, pi := range pis {
importComments = append(importComments, &ControllerImportComments{
ImportPath: pi.importPath,
ImportAlias: pi.importAlias,
})
}
return importComments
}
func buildFilters(pfs []parsedFilter) []*ControllerFilterComments {
var filterComments []*ControllerFilterComments
for _, pf := range pfs {
var (
returnOnOutput bool
resetParams bool
)
if len(pf.params) >= 1 {
returnOnOutput = pf.params[0]
}
if len(pf.params) >= 2 {
resetParams = pf.params[1]
}
filterComments = append(filterComments, &ControllerFilterComments{
Filter: pf.filter,
Pattern: pf.pattern,
Pos: pf.pos,
ReturnOnOutput: returnOnOutput,
ResetParams: resetParams,
})
}
return filterComments
}
func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
result := make([]*param.MethodParam, 0, len(funcParams))
for _, fparam := range funcParams {
for _, pName := range fparam.Names {
methodParam := buildMethodParam(fparam, pName.Name, pc)
result = append(result, methodParam)
}
}
return result
}
func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam {
options := []param.MethodParamOption{}
if cparam, ok := pc.params[name]; ok {
//Build param from comment info
name = cparam.name
if cparam.required {
options = append(options, param.IsRequired)
}
switch cparam.location {
case "body":
options = append(options, param.InBody)
case "header":
options = append(options, param.InHeader)
case "path":
options = append(options, param.InPath)
}
if cparam.defValue != "" {
options = append(options, param.Default(cparam.defValue))
}
} else {
if paramInPath(name, pc.routerPath) {
options = append(options, param.InPath)
}
}
return param.New(name, options...)
}
func paramInPath(name, route string) bool {
return strings.HasSuffix(route, ":"+name) ||
strings.Contains(route, ":"+name+"/")
}
var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) {
pcs = []*parsedComment{}
params := map[string]parsedParam{}
filters := []parsedFilter{}
imports := []parsedImport{}
for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@Param") {
pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param")))
if len(pv) < 4 {
logs.Error("Invalid @Param format. Needs at least 4 parameters")
}
p := parsedParam{}
names := strings.SplitN(pv[0], "=>", 2)
p.name = names[0]
funcParamName := p.name
if len(names) > 1 {
funcParamName = names[1]
}
p.location = pv[1]
p.datatype = pv[2]
switch len(pv) {
case 5:
p.required, _ = strconv.ParseBool(pv[3])
case 6:
p.defValue = pv[3]
p.required, _ = strconv.ParseBool(pv[4])
}
params[funcParamName] = p
}
}
for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@Import") {
iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import")))
if len(iv) == 0 || len(iv) > 2 {
logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters")
continue
}
p := parsedImport{}
p.importPath = iv[0]
if len(iv) == 2 {
p.importAlias = iv[1]
}
imports = append(imports, p)
}
}
filterLoop:
for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@Filter") {
fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter")))
if len(fv) < 3 {
logs.Error("Invalid @Filter format. Needs at least 3 parameters")
continue filterLoop
}
p := parsedFilter{}
p.pattern = fv[0]
posName := fv[1]
if pos, exists := routerHooks[posName]; exists {
p.pos = pos
} else {
logs.Error("Invalid @Filter pos: ", posName)
continue filterLoop
}
p.filter = fv[2]
fvParams := fv[3:]
for _, fvParam := range fvParams {
switch fvParam {
case "true":
p.params = append(p.params, true)
case "false":
p.params = append(p.params, false)
default:
logs.Error("Invalid @Filter param: ", fvParam)
continue filterLoop
}
}
filters = append(filters, p)
}
}
for _, c := range lines {
var pc = &parsedComment{}
pc.params = params
pc.filters = filters
pc.imports = imports
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@router") {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
matches := routeRegex.FindStringSubmatch(t)
if len(matches) == 3 {
pc.routerPath = matches[1]
methods := matches[2]
if methods == "" {
pc.methods = []string{"get"}
//pc.hasGet = true
} else {
pc.methods = strings.Split(methods, ",")
//pc.hasGet = strings.Contains(methods, "get")
}
pcs = append(pcs, pc)
} else {
return nil, errors.New("Router information is missing")
}
}
}
return
}
// direct copy from bee\g_docs.go
// analysis params return []string
// @Param query form string true "The email for login"
// [query form string true "The email for login"]
func getparams(str string) []string {
var s []rune
var j int
var start bool
var r []string
var quoted int8
for _, c := range str {
if unicode.IsSpace(c) && quoted == 0 {
if !start {
continue
} else {
start = false
j++
r = append(r, string(s))
s = make([]rune, 0)
continue
}
}
start = true
if c == '"' {
quoted ^= 1
continue
}
s = append(s, c)
}
if len(s) > 0 {
r = append(r, string(s))
}
return r
}
func genRouterCode(pkgRealpath string) {
os.Mkdir(getRouterDir(pkgRealpath), 0755)
logs.Info("generate router from comments")
var (
globalinfo string
globalimport string
sortKey []string
)
for k := range genInfoList {
sortKey = append(sortKey, k)
}
sort.Strings(sortKey)
for _, k := range sortKey {
cList := genInfoList[k]
sort.Sort(ControllerCommentsSlice(cList))
for _, c := range cList {
allmethod := "nil"
if len(c.AllowHTTPMethods) > 0 {
allmethod = "[]string{"
for _, m := range c.AllowHTTPMethods {
allmethod += `"` + m + `",`
}
allmethod = strings.TrimRight(allmethod, ",") + "}"
}
params := "nil"
if len(c.Params) > 0 {
params = "[]map[string]string{"
for _, p := range c.Params {
for k, v := range p {
params = params + `map[string]string{` + k + `:"` + v + `"},`
}
}
params = strings.TrimRight(params, ",") + "}"
}
methodParams := "param.Make("
if len(c.MethodParams) > 0 {
lines := make([]string, 0, len(c.MethodParams))
for _, m := range c.MethodParams {
lines = append(lines, fmt.Sprint(m))
}
methodParams += "\n " +
strings.Join(lines, ",\n ") +
",\n "
}
methodParams += ")"
imports := ""
if len(c.ImportComments) > 0 {
for _, i := range c.ImportComments {
var s string
if i.ImportAlias != "" {
s = fmt.Sprintf(`
%s "%s"`, i.ImportAlias, i.ImportPath)
} else {
s = fmt.Sprintf(`
"%s"`, i.ImportPath)
}
if !strings.Contains(globalimport, s) {
imports += s
}
}
}
filters := ""
if len(c.FilterComments) > 0 {
for _, f := range c.FilterComments {
filters += fmt.Sprintf(` &beego.ControllerFilter{
Pattern: "%s",
Pos: %s,
Filter: %s,
ReturnOnOutput: %v,
ResetParams: %v,
},`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams)
}
}
if filters == "" {
filters = "nil"
} else {
filters = fmt.Sprintf(`[]*beego.ControllerFilter{
%s
}`, filters)
}
globalimport += imports
globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{
Method: "` + strings.TrimSpace(c.Method) + `",
` + `Router: "` + c.Router + `"` + `,
AllowHTTPMethods: ` + allmethod + `,
MethodParams: ` + methodParams + `,
Filters: ` + filters + `,
Params: ` + params + `})
`
}
}
if globalinfo != "" {
f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename))
if err != nil {
panic(err)
}
defer f.Close()
routersDir := AppConfig.DefaultString(context.Background(), "routersdir", "routers")
content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1)
content = strings.Replace(content, "{{.routersDir}}", routersDir, -1)
content = strings.Replace(content, "{{.globalimport}}", globalimport, -1)
f.WriteString(content)
}
}
func compareFile(pkgRealpath string) bool {
if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) {
return true
}
if utils.FileExists(lastupdateFilename) {
content, err := ioutil.ReadFile(lastupdateFilename)
if err != nil {
return true
}
json.Unmarshal(content, &pkgLastupdate)
lastupdate, err := getpathTime(pkgRealpath)
if err != nil {
return true
}
if v, ok := pkgLastupdate[pkgRealpath]; ok {
if lastupdate <= v {
return false
}
}
}
return true
}
func savetoFile(pkgRealpath string) {
lastupdate, err := getpathTime(pkgRealpath)
if err != nil {
return
}
pkgLastupdate[pkgRealpath] = lastupdate
d, err := json.Marshal(pkgLastupdate)
if err != nil {
return
}
ioutil.WriteFile(lastupdateFilename, d, os.ModePerm)
}
func getpathTime(pkgRealpath string) (lastupdate int64, err error) {
fl, err := ioutil.ReadDir(pkgRealpath)
if err != nil {
return lastupdate, err
}
for _, f := range fl {
var t int64
if f.IsDir() {
t, err = getpathTime(filepath.Join(pkgRealpath, f.Name()))
if err != nil {
return lastupdate, err
}
} else {
t = f.ModTime().UnixNano()
}
if lastupdate < t {
lastupdate = t
}
}
return lastupdate, nil
}
func getRouterDir(pkgRealpath string) string {
dir := filepath.Dir(pkgRealpath)
for {
routersDir := AppConfig.DefaultString(context.Background(), "routersdir", "routers")
d := filepath.Join(dir, routersDir)
if utils.FileExists(d) {
return d
}
if r, _ := filepath.Rel(dir, AppPath); r == "." {
return d
}
// Parent dir.
dir = filepath.Dir(dir)
}
}

97
server/web/policy.go Normal file
View File

@ -0,0 +1,97 @@
// Copyright 2016 beego authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"strings"
"github.com/astaxie/beego/server/web/context"
)
// PolicyFunc defines a policy function which is invoked before the controller handler is executed.
type PolicyFunc func(*context.Context)
// FindPolicy Find Router info for URL
func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc {
var urlPath = cont.Input.URL()
if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(urlPath)
}
httpMethod := cont.Input.Method()
isWildcard := false
// Find policy for current method
t, ok := p.policies[httpMethod]
// If not found - find policy for whole controller
if !ok {
t, ok = p.policies["*"]
isWildcard = true
}
if ok {
runObjects := t.Match(urlPath, cont)
if r, ok := runObjects.([]PolicyFunc); ok {
return r
} else if !isWildcard {
// If no policies found and we checked not for "*" method - try to find it
t, ok = p.policies["*"]
if ok {
runObjects = t.Match(urlPath, cont)
if r, ok = runObjects.([]PolicyFunc); ok {
return r
}
}
}
}
return nil
}
func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc) {
method = strings.ToUpper(method)
p.enablePolicy = true
if !BConfig.RouterCaseSensitive {
pattern = strings.ToLower(pattern)
}
if t, ok := p.policies[method]; ok {
t.AddRouter(pattern, r)
} else {
t := NewTree()
t.AddRouter(pattern, r)
p.policies[method] = t
}
}
// Policy Register new policy in beego
func Policy(pattern, method string, policy ...PolicyFunc) {
BeeApp.Handlers.addToPolicy(method, pattern, policy...)
}
// Find policies and execute if were found
func (p *ControllerRegister) execPolicy(cont *context.Context, urlPath string) (started bool) {
if !p.enablePolicy {
return false
}
// Find Policy for method
policyList := p.FindPolicy(cont)
if len(policyList) > 0 {
// Run policies
for _, runPolicy := range policyList {
runPolicy(cont)
if cont.ResponseWriter.Started {
return true
}
}
return false
}
return false
}

997
server/web/router.go Normal file
View File

@ -0,0 +1,997 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"errors"
"fmt"
"net/http"
"path"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/core/logs"
"github.com/astaxie/beego/core/utils"
beecontext "github.com/astaxie/beego/server/web/context"
"github.com/astaxie/beego/server/web/context/param"
)
// default filter execution points
const (
BeforeStatic = iota
BeforeRouter
BeforeExec
AfterExec
FinishRouter
)
const (
routerTypeBeego = iota
routerTypeRESTFul
routerTypeHandler
)
var (
// HTTPMETHOD list the supported http methods.
HTTPMETHOD = map[string]bool{
"GET": true,
"POST": true,
"PUT": true,
"DELETE": true,
"PATCH": true,
"OPTIONS": true,
"HEAD": true,
"TRACE": true,
"CONNECT": true,
"MKCOL": true,
"COPY": true,
"MOVE": true,
"PROPFIND": true,
"PROPPATCH": true,
"LOCK": true,
"UNLOCK": true,
}
// these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
"RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP",
"ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
"GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
"DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
"GetControllerAndAction", "ServeFormatted"}
urlPlaceholder = "{{placeholder}}"
// DefaultAccessLogFilter will skip the accesslog if return true
DefaultAccessLogFilter FilterHandler = &logFilter{}
)
// FilterHandler is an interface for
type FilterHandler interface {
Filter(*beecontext.Context) bool
}
// default log filter static file will not show
type logFilter struct {
}
func (l *logFilter) Filter(ctx *beecontext.Context) bool {
requestPath := path.Clean(ctx.Request.URL.Path)
if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
return true
}
for prefix := range BConfig.WebConfig.StaticDir {
if strings.HasPrefix(requestPath, prefix) {
return true
}
}
return false
}
// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter
func ExceptMethodAppend(action string) {
exceptMethod = append(exceptMethod, action)
}
// ControllerInfo holds information about the controller.
type ControllerInfo struct {
pattern string
controllerType reflect.Type
methods map[string]string
handler http.Handler
runFunction FilterFunc
routerType int
initialize func() ControllerInterface
methodParams []*param.MethodParam
}
func (c *ControllerInfo) GetPattern() string {
return c.pattern
}
// ControllerRegister containers registered router rules, controller handlers and filters.
type ControllerRegister struct {
routers map[string]*Tree
enablePolicy bool
policies map[string]*Tree
enableFilter bool
filters [FinishRouter + 1][]*FilterRouter
pool sync.Pool
// the filter created by FilterChain
chainRoot *FilterRouter
cfg *Config
}
// NewControllerRegister returns a new ControllerRegister.
// Usually you should not use this method
// please use NewControllerRegisterWithCfg
func NewControllerRegister() *ControllerRegister {
return NewControllerRegisterWithCfg(BeeApp.Cfg)
}
func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
res := &ControllerRegister{
routers: make(map[string]*Tree),
policies: make(map[string]*Tree),
pool: sync.Pool{
New: func() interface{} {
return beecontext.NewContext()
},
},
cfg: cfg,
}
res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false))
return res
}
// Add controller handler and pattern rules to ControllerRegister.
// usage:
// default methods is the same name as method
// Add("/user",&UserController{})
// Add("/api/list",&RestController{},"*:ListFood")
// Add("/api/create",&RestController{},"post:CreateFood")
// Add("/api/update",&RestController{},"put:UpdateFood")
// Add("/api/delete",&RestController{},"delete:DeleteFood")
// Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
p.addWithMethodParams(pattern, c, nil, mappingMethods...)
}
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
methods := make(map[string]string)
if len(mappingMethods) > 0 {
semi := strings.Split(mappingMethods[0], ";")
for _, v := range semi {
colon := strings.Split(v, ":")
if len(colon) != 2 {
panic("method mapping format is invalid")
}
comma := strings.Split(colon[0], ",")
for _, m := range comma {
if m == "*" || HTTPMETHOD[strings.ToUpper(m)] {
if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
methods[strings.ToUpper(m)] = colon[1]
} else {
panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name())
}
} else {
panic(v + " is an invalid method mapping. Method doesn't exist " + m)
}
}
}
}
route := &ControllerInfo{}
route.pattern = pattern
route.methods = methods
route.routerType = routerTypeBeego
route.controllerType = t
route.initialize = func() ControllerInterface {
vc := reflect.New(route.controllerType)
execController, ok := vc.Interface().(ControllerInterface)
if !ok {
panic("controller is not ControllerInterface")
}
elemVal := reflect.ValueOf(c).Elem()
elemType := reflect.TypeOf(c).Elem()
execElem := reflect.ValueOf(execController).Elem()
numOfFields := elemVal.NumField()
for i := 0; i < numOfFields; i++ {
fieldType := elemType.Field(i)
elemField := execElem.FieldByName(fieldType.Name)
if elemField.CanSet() {
fieldVal := elemVal.Field(i)
elemField.Set(fieldVal)
}
}
return execController
}
route.methodParams = methodParams
if len(methods) == 0 {
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
for k := range methods {
if k == "*" {
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
p.addToRouter(k, pattern, route)
}
}
}
}
func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) {
if !p.cfg.RouterCaseSensitive {
pattern = strings.ToLower(pattern)
}
if t, ok := p.routers[method]; ok {
t.AddRouter(pattern, r)
} else {
t := NewTree()
t.AddRouter(pattern, r)
p.routers[method] = t
}
}
// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller
// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
func (p *ControllerRegister) Include(cList ...ControllerInterface) {
for _, c := range cList {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
key := t.PkgPath() + ":" + t.Name()
if comm, ok := GlobalControllerRouter[key]; ok {
for _, a := range comm {
for _, f := range a.Filters {
p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams))
}
p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
}
}
}
}
// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context
// And don't forget to give back context to pool
// example:
// ctx := p.GetContext()
// ctx.Reset(w, q)
// defer p.GiveBackContext(ctx)
func (p *ControllerRegister) GetContext() *beecontext.Context {
return p.pool.Get().(*beecontext.Context)
}
// GiveBackContext put the ctx into pool so that it could be reuse
func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) {
p.pool.Put(ctx)
}
// Get add get method
// usage:
// Get("/", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (p *ControllerRegister) Get(pattern string, f FilterFunc) {
p.AddMethod("get", pattern, f)
}
// Post add post method
// usage:
// Post("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (p *ControllerRegister) Post(pattern string, f FilterFunc) {
p.AddMethod("post", pattern, f)
}
// Put add put method
// usage:
// Put("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (p *ControllerRegister) Put(pattern string, f FilterFunc) {
p.AddMethod("put", pattern, f)
}
// Delete add delete method
// usage:
// Delete("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (p *ControllerRegister) Delete(pattern string, f FilterFunc) {
p.AddMethod("delete", pattern, f)
}
// Head add head method
// usage:
// Head("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (p *ControllerRegister) Head(pattern string, f FilterFunc) {
p.AddMethod("head", pattern, f)
}
// Patch add patch method
// usage:
// Patch("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (p *ControllerRegister) Patch(pattern string, f FilterFunc) {
p.AddMethod("patch", pattern, f)
}
// Options add options method
// usage:
// Options("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (p *ControllerRegister) Options(pattern string, f FilterFunc) {
p.AddMethod("options", pattern, f)
}
// Any add all method
// usage:
// Any("/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (p *ControllerRegister) Any(pattern string, f FilterFunc) {
p.AddMethod("*", pattern, f)
}
// AddMethod add http method router
// usage:
// AddMethod("get","/api/:id", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
method = strings.ToUpper(method)
if method != "*" && !HTTPMETHOD[method] {
panic("not support http method: " + method)
}
route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeRESTFul
route.runFunction = f
methods := make(map[string]string)
if method == "*" {
for val := range HTTPMETHOD {
methods[val] = val
}
} else {
methods[method] = method
}
route.methods = methods
for k := range methods {
if k == "*" {
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
p.addToRouter(k, pattern, route)
}
}
}
// Handler add user defined Handler
func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeHandler
route.handler = h
if len(options) > 0 {
if _, ok := options[0].(bool); ok {
pattern = path.Join(pattern, "?:all(.*)")
}
}
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
}
// AddAuto router to ControllerRegister.
// example beego.AddAuto(&MainContorlller{}),
// MainController has method List and Page.
// visit the url /main/list to execute List function
// /main/page to execute Page function.
func (p *ControllerRegister) AddAuto(c ControllerInterface) {
p.AddAutoPrefix("/", c)
}
// AddAutoPrefix Add auto router to ControllerRegister with prefix.
// example beego.AddAutoPrefix("/admin",&MainContorlller{}),
// MainController has method List and Page.
// visit the url /admin/main/list to execute List function
// /admin/main/page to execute Page function.
func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) {
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
controllerName := strings.TrimSuffix(ct.Name(), "Controller")
for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
route := &ControllerInfo{}
route.routerType = routerTypeBeego
route.methods = map[string]string{"*": rt.Method(i).Name}
route.controllerType = ct
pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*")
patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*")
patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name))
patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name)
route.pattern = pattern
for m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
p.addToRouter(m, patternInit, route)
p.addToRouter(m, patternFix, route)
p.addToRouter(m, patternFixInit, route)
}
}
}
}
// InsertFilter Add a FilterFunc with pattern rule and action constant.
// params is for:
// 1. setting the returnOnOutput value (false allows multiple filters to execute)
// 2. determining whether or not params need to be reset.
func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) error {
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
mr := newFilterRouter(pattern, filter, opts...)
return p.insertFilterRouter(pos, mr)
}
// InsertFilterChain is similar to InsertFilter,
// but it will using chainRoot.filterFunc as input to build a new filterFunc
// for example, assume that chainRoot is funcA
// and we add new FilterChain
// fc := func(next) {
// return func(ctx) {
// // do something
// next(ctx)
// // do something
// }
// }
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) {
root := p.chainRoot
filterFunc := chain(root.filterFunc)
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
p.chainRoot = newFilterRouter(pattern, filterFunc, opts...)
p.chainRoot.next = root
}
// add Filter into
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
if pos < BeforeStatic || pos > FinishRouter {
return errors.New("can not find your filter position")
}
p.enableFilter = true
p.filters[pos] = append(p.filters[pos], mr)
return nil
}
// URLFor does another controller handler in this request function.
// it can access any controller method.
func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string {
paths := strings.Split(endpoint, ".")
if len(paths) <= 1 {
logs.Warn("urlfor endpoint must like path.controller.method")
return ""
}
if len(values)%2 != 0 {
logs.Warn("urlfor params must key-value pair")
return ""
}
params := make(map[string]string)
if len(values) > 0 {
key := ""
for k, v := range values {
if k%2 == 0 {
key = fmt.Sprint(v)
} else {
params[key] = fmt.Sprint(v)
}
}
}
controllerName := strings.Join(paths[:len(paths)-1], "/")
methodName := paths[len(paths)-1]
for m, t := range p.routers {
ok, url := p.getURL(t, "/", controllerName, methodName, params, m)
if ok {
return url
}
}
return ""
}
func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) {
for _, subtree := range t.fixrouters {
u := path.Join(url, subtree.prefix)
ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod)
if ok {
return ok, u
}
}
if t.wildcard != nil {
u := path.Join(url, urlPlaceholder)
ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod)
if ok {
return ok, u
}
}
for _, l := range t.leaves {
if c, ok := l.runObject.(*ControllerInfo); ok {
if c.routerType == routerTypeBeego &&
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) {
find := false
if HTTPMETHOD[strings.ToUpper(methodName)] {
if len(c.methods) == 0 {
find = true
} else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) {
find = true
} else if m, ok = c.methods["*"]; ok && m == methodName {
find = true
}
}
if !find {
for m, md := range c.methods {
if (m == "*" || m == httpMethod) && md == methodName {
find = true
}
}
}
if find {
if l.regexps == nil {
if len(l.wildcards) == 0 {
return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toURL(params)
}
if len(l.wildcards) == 1 {
if v, ok := params[l.wildcards[0]]; ok {
delete(params, l.wildcards[0])
return true, strings.Replace(url, urlPlaceholder, v, 1) + toURL(params)
}
return false, ""
}
if len(l.wildcards) == 3 && l.wildcards[0] == "." {
if p, ok := params[":path"]; ok {
if e, isok := params[":ext"]; isok {
delete(params, ":path")
delete(params, ":ext")
return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toURL(params)
}
}
}
canSkip := false
for _, v := range l.wildcards {
if v == ":" {
canSkip = true
continue
}
if u, ok := params[v]; ok {
delete(params, v)
url = strings.Replace(url, urlPlaceholder, u, 1)
} else {
if canSkip {
canSkip = false
continue
}
return false, ""
}
}
return true, url + toURL(params)
}
var i int
var startReg bool
regURL := ""
for _, v := range strings.Trim(l.regexps.String(), "^$") {
if v == '(' {
startReg = true
continue
} else if v == ')' {
startReg = false
if v, ok := params[l.wildcards[i]]; ok {
delete(params, l.wildcards[i])
regURL = regURL + v
i++
} else {
break
}
} else if !startReg {
regURL = string(append([]rune(regURL), v))
}
}
if l.regexps.MatchString(regURL) {
ps := strings.Split(regURL, "/")
for _, p := range ps {
url = strings.Replace(url, urlPlaceholder, p, 1)
}
return true, url + toURL(params)
}
}
}
}
}
return false, ""
}
func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) {
var preFilterParams map[string]string
for _, filterR := range p.filters[pos] {
b, done := filterR.filter(context, urlPath, preFilterParams)
if done {
return b
}
}
return false
}
// Implement http.Handler interface.
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
ctx := p.GetContext()
ctx.Reset(rw, r)
defer p.GiveBackContext(ctx)
var preFilterParams map[string]string
p.chainRoot.filter(ctx, p.getUrlPath(ctx), preFilterParams)
}
func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) {
startTime := time.Now()
r := ctx.Request
rw := ctx.ResponseWriter.ResponseWriter
var (
runRouter reflect.Type
findRouter bool
runMethod string
methodParams []*param.MethodParam
routerInfo *ControllerInfo
isRunnable bool
)
if p.cfg.RecoverFunc != nil {
defer p.cfg.RecoverFunc(ctx, p.cfg)
}
ctx.Output.EnableGzip = p.cfg.EnableGzip
if p.cfg.RunMode == DEV {
ctx.Output.Header("Server", p.cfg.ServerName)
}
urlPath := p.getUrlPath(ctx)
// filter wrong http method
if !HTTPMETHOD[r.Method] {
exception("405", ctx)
goto Admin
}
// filter for static file
if len(p.filters[BeforeStatic]) > 0 && p.execFilter(ctx, urlPath, BeforeStatic) {
goto Admin
}
serverStaticRouter(ctx)
if ctx.ResponseWriter.Started {
findRouter = true
goto Admin
}
if r.Method != http.MethodGet && r.Method != http.MethodHead {
if p.cfg.CopyRequestBody && !ctx.Input.IsUpload() {
// connection will close if the incoming data are larger (RFC 7231, 6.5.11)
if r.ContentLength > p.cfg.MaxMemory {
logs.Error(errors.New("payload too large"))
exception("413", ctx)
goto Admin
}
ctx.Input.CopyBody(p.cfg.MaxMemory)
}
ctx.Input.ParseFormOrMulitForm(p.cfg.MaxMemory)
}
// session init
if p.cfg.WebConfig.Session.SessionOn {
var err error
ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r)
if err != nil {
logs.Error(err)
exception("503", ctx)
goto Admin
}
defer func() {
if ctx.Input.CruSession != nil {
ctx.Input.CruSession.SessionRelease(nil, rw)
}
}()
}
if len(p.filters[BeforeRouter]) > 0 && p.execFilter(ctx, urlPath, BeforeRouter) {
goto Admin
}
// User can define RunController and RunMethod in filter
if ctx.Input.RunController != nil && ctx.Input.RunMethod != "" {
findRouter = true
runMethod = ctx.Input.RunMethod
runRouter = ctx.Input.RunController
} else {
routerInfo, findRouter = p.FindRouter(ctx)
}
// if no matches to url, throw a not found exception
if !findRouter {
exception("404", ctx)
goto Admin
}
if splat := ctx.Input.Param(":splat"); splat != "" {
for k, v := range strings.Split(splat, "/") {
ctx.Input.SetParam(strconv.Itoa(k), v)
}
}
if routerInfo != nil {
// store router pattern into context
ctx.Input.SetData("RouterPattern", routerInfo.pattern)
}
// execute middleware filters
if len(p.filters[BeforeExec]) > 0 && p.execFilter(ctx, urlPath, BeforeExec) {
goto Admin
}
// check policies
if p.execPolicy(ctx, urlPath) {
goto Admin
}
if routerInfo != nil {
if routerInfo.routerType == routerTypeRESTFul {
if _, ok := routerInfo.methods[r.Method]; ok {
isRunnable = true
routerInfo.runFunction(ctx)
} else {
exception("405", ctx)
goto Admin
}
} else if routerInfo.routerType == routerTypeHandler {
isRunnable = true
routerInfo.handler.ServeHTTP(ctx.ResponseWriter, ctx.Request)
} else {
runRouter = routerInfo.controllerType
methodParams = routerInfo.methodParams
method := r.Method
if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodPut {
method = http.MethodPut
}
if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodDelete {
method = http.MethodDelete
}
if m, ok := routerInfo.methods[method]; ok {
runMethod = m
} else if m, ok = routerInfo.methods["*"]; ok {
runMethod = m
} else {
runMethod = method
}
}
}
// also defined runRouter & runMethod from filter
if !isRunnable {
// Invoke the request handler
var execController ControllerInterface
if routerInfo != nil && routerInfo.initialize != nil {
execController = routerInfo.initialize()
} else {
vc := reflect.New(runRouter)
var ok bool
execController, ok = vc.Interface().(ControllerInterface)
if !ok {
panic("controller is not ControllerInterface")
}
}
// call the controller init function
execController.Init(ctx, runRouter.Name(), runMethod, execController)
// call prepare function
execController.Prepare()
// if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
if p.cfg.WebConfig.EnableXSRF {
execController.XSRFToken()
if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut ||
(r.Method == http.MethodPost && (ctx.Input.Query("_method") == http.MethodDelete || ctx.Input.Query("_method") == http.MethodPut)) {
execController.CheckXSRFCookie()
}
}
execController.URLMapping()
if !ctx.ResponseWriter.Started {
// exec main logic
switch runMethod {
case http.MethodGet:
execController.Get()
case http.MethodPost:
execController.Post()
case http.MethodDelete:
execController.Delete()
case http.MethodPut:
execController.Put()
case http.MethodHead:
execController.Head()
case http.MethodPatch:
execController.Patch()
case http.MethodOptions:
execController.Options()
case http.MethodTrace:
execController.Trace()
default:
if !execController.HandlerFunc(runMethod) {
vc := reflect.ValueOf(execController)
method := vc.MethodByName(runMethod)
in := param.ConvertParams(methodParams, method.Type(), ctx)
out := method.Call(in)
// For backward compatibility we only handle response if we had incoming methodParams
if methodParams != nil {
p.handleParamResponse(ctx, execController, out)
}
}
}
// render template
if !ctx.ResponseWriter.Started && ctx.Output.Status == 0 {
if p.cfg.WebConfig.AutoRender {
if err := execController.Render(); err != nil {
logs.Error(err)
}
}
}
}
// finish all runRouter. release resource
execController.Finish()
}
// execute middleware filters
if len(p.filters[AfterExec]) > 0 && p.execFilter(ctx, urlPath, AfterExec) {
goto Admin
}
if len(p.filters[FinishRouter]) > 0 && p.execFilter(ctx, urlPath, FinishRouter) {
goto Admin
}
Admin:
// admin module record QPS
statusCode := ctx.ResponseWriter.Status
if statusCode == 0 {
statusCode = 200
}
LogAccess(ctx, &startTime, statusCode)
timeDur := time.Since(startTime)
ctx.ResponseWriter.Elapsed = timeDur
if p.cfg.Listen.EnableAdmin {
pattern := ""
if routerInfo != nil {
pattern = routerInfo.pattern
}
if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) {
routerName := ""
if runRouter != nil {
routerName = runRouter.Name()
}
go StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur)
}
}
if p.cfg.RunMode == DEV && !p.cfg.Log.AccessLogs {
match := map[bool]string{true: "match", false: "nomatch"}
devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s",
ctx.Input.IP(),
logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(),
timeDur.String(),
match[findRouter],
logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(),
r.URL.Path)
if routerInfo != nil {
devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern)
}
logs.Debug(devInfo)
}
// Call WriteHeader if status code has been set changed
if ctx.Output.Status != 0 {
ctx.ResponseWriter.WriteHeader(ctx.Output.Status)
}
}
func (p *ControllerRegister) getUrlPath(ctx *beecontext.Context) string {
urlPath := ctx.Request.URL.Path
if !p.cfg.RouterCaseSensitive {
urlPath = strings.ToLower(urlPath)
}
return urlPath
}
func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
// looping in reverse order for the case when both error and value are returned and error sets the response status code
for i := len(results) - 1; i >= 0; i-- {
result := results[i]
if result.Kind() != reflect.Interface || !result.IsNil() {
resultValue := result.Interface()
context.RenderMethodResult(resultValue)
}
}
if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 {
context.Output.SetStatus(200)
}
}
// FindRouter Find Router info for URL
func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) {
var urlPath = context.Input.URL()
if !p.cfg.RouterCaseSensitive {
urlPath = strings.ToLower(urlPath)
}
httpMethod := context.Input.Method()
if t, ok := p.routers[httpMethod]; ok {
runObject := t.Match(urlPath, context)
if r, ok := runObject.(*ControllerInfo); ok {
return r, true
}
}
return
}
func toURL(params map[string]string) string {
if len(params) == 0 {
return ""
}
u := "?"
for k, v := range params {
u += k + "=" + v + "&"
}
return strings.TrimRight(u, "&")
}
// LogAccess logging info HTTP Access
func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) {
BeeApp.LogAccess(ctx, startTime, statusCode)
}

752
server/web/router_test.go Normal file
View File

@ -0,0 +1,752 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/astaxie/beego/core/logs"
"github.com/astaxie/beego/server/web/context"
)
type TestController struct {
Controller
}
func (tc *TestController) Get() {
tc.Data["Username"] = "astaxie"
tc.Ctx.Output.Body([]byte("ok"))
}
func (tc *TestController) Post() {
tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name")))
}
func (tc *TestController) Param() {
tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Query(":name")))
}
func (tc *TestController) List() {
tc.Ctx.Output.Body([]byte("i am list"))
}
func (tc *TestController) Params() {
tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param("0") + tc.Ctx.Input.Param("1") + tc.Ctx.Input.Param("2")))
}
func (tc *TestController) Myext() {
tc.Ctx.Output.Body([]byte(tc.Ctx.Input.Param(":ext")))
}
func (tc *TestController) GetURL() {
tc.Ctx.Output.Body([]byte(tc.URLFor(".Myext")))
}
func (tc *TestController) GetParams() {
tc.Ctx.WriteString(tc.Ctx.Input.Query(":last") + "+" +
tc.Ctx.Input.Query(":first") + "+" + tc.Ctx.Input.Query("learn"))
}
func (tc *TestController) GetManyRouter() {
tc.Ctx.WriteString(tc.Ctx.Input.Query(":id") + tc.Ctx.Input.Query(":page"))
}
func (tc *TestController) GetEmptyBody() {
var res []byte
tc.Ctx.Output.Body(res)
}
type JSONController struct {
Controller
}
func (jc *JSONController) Prepare() {
jc.Data["json"] = "prepare"
jc.ServeJSON(true)
}
func (jc *JSONController) Get() {
jc.Data["Username"] = "astaxie"
jc.Ctx.Output.Body([]byte("ok"))
}
func TestUrlFor(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, "*:List")
handler.Add("/person/:last/:first", &TestController{}, "*:Param")
if a := handler.URLFor("TestController.List"); a != "/api/list" {
logs.Info(a)
t.Errorf("TestController.List must equal to /api/list")
}
if a := handler.URLFor("TestController.Param", ":last", "xie", ":first", "asta"); a != "/person/xie/asta" {
t.Errorf("TestController.Param must equal to /person/xie/asta, but get " + a)
}
}
func TestUrlFor3(t *testing.T) {
handler := NewControllerRegister()
handler.AddAuto(&TestController{})
if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" {
t.Errorf("TestController.Myext must equal to /test/myext, but get " + a)
}
if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" {
t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a)
}
}
func TestUrlFor2(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List")
handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL")
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param")
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
logs.Info(handler.URLFor("TestController.GetURL"))
t.Errorf("TestController.List must equal to /v1/astaxie/edit")
}
if handler.URLFor("TestController.List", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za/cms_12_123.html" {
logs.Info(handler.URLFor("TestController.List"))
t.Errorf("TestController.List must equal to /v1/za/cms_12_123.html")
}
if handler.URLFor("TestController.Param", ":v", "za", ":id", "12", ":page", "123") !=
"/v1/za_cms/ttt_12_123.html" {
logs.Info(handler.URLFor("TestController.Param"))
t.Errorf("TestController.List must equal to /v1/za_cms/ttt_12_123.html")
}
if handler.URLFor("TestController.Get", ":year", "1111", ":month", "11",
":title", "aaaa", ":entid", "aaaa") !=
"/1111/11/aaaa/aaaa" {
logs.Info(handler.URLFor("TestController.Get"))
t.Errorf("TestController.Get must equal to /1111/11/aaaa/aaaa")
}
}
func TestUserFunc(t *testing.T) {
r, _ := http.NewRequest("GET", "/api/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, "*:List")
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("user define func can't run")
}
}
func TestPostFunc(t *testing.T) {
r, _ := http.NewRequest("POST", "/astaxie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/:name", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "astaxie" {
t.Errorf("post func should astaxie")
}
}
func TestAutoFunc(t *testing.T) {
r, _ := http.NewRequest("GET", "/test/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.AddAuto(&TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("user define func can't run")
}
}
func TestAutoFunc2(t *testing.T) {
r, _ := http.NewRequest("GET", "/Test/List", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.AddAuto(&TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("user define func can't run")
}
}
func TestAutoFuncParams(t *testing.T) {
r, _ := http.NewRequest("GET", "/test/params/2009/11/12", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.AddAuto(&TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "20091112" {
t.Errorf("user define func can't run")
}
}
func TestAutoExtFunc(t *testing.T) {
r, _ := http.NewRequest("GET", "/test/myext.json", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.AddAuto(&TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "json" {
t.Errorf("user define func can't run")
}
}
func TestEscape(t *testing.T) {
r, _ := http.NewRequest("GET", "/search/%E4%BD%A0%E5%A5%BD", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Get("/search/:keyword(.+)", func(ctx *context.Context) {
value := ctx.Input.Param(":keyword")
ctx.Output.Body([]byte(value))
})
handler.ServeHTTP(w, r)
str := w.Body.String()
if str != "你好" {
t.Errorf("incorrect, %s", str)
}
}
func TestRouteOk(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/person/:last/:first", &TestController{}, "get:GetParams")
handler.ServeHTTP(w, r)
body := w.Body.String()
if body != "anderson+thomas+kungfu" {
t.Errorf("url param set to [%s];", body)
}
}
func TestManyRoute(t *testing.T) {
r, _ := http.NewRequest("GET", "/beego32-12.html", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter")
handler.ServeHTTP(w, r)
body := w.Body.String()
if body != "3212" {
t.Errorf("url param set to [%s];", body)
}
}
// Test for issue #1669
func TestEmptyResponse(t *testing.T) {
r, _ := http.NewRequest("GET", "/beego-empty.html", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody")
handler.ServeHTTP(w, r)
if body := w.Body.String(); body != "" {
t.Error("want empty body")
}
}
func TestNotFound(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.ServeHTTP(w, r)
if w.Code != http.StatusNotFound {
t.Errorf("Code set to [%v]; want [%v]", w.Code, http.StatusNotFound)
}
}
// TestStatic tests the ability to serve static
// content from the filesystem
func TestStatic(t *testing.T) {
r, _ := http.NewRequest("GET", "/static/js/jquery.js", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.ServeHTTP(w, r)
if w.Code != 404 {
t.Errorf("handler.Static failed to serve file")
}
}
func TestPrepare(t *testing.T) {
r, _ := http.NewRequest("GET", "/json/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/json/list", &JSONController{})
handler.ServeHTTP(w, r)
if w.Body.String() != `"prepare"` {
t.Errorf(w.Body.String() + "user define func can't run")
}
}
func TestAutoPrefix(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/test/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.AddAutoPrefix("/admin", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("TestAutoPrefix can't run")
}
}
func TestRouterGet(t *testing.T) {
r, _ := http.NewRequest("GET", "/user", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Get("/user", func(ctx *context.Context) {
ctx.Output.Body([]byte("Get userlist"))
})
handler.ServeHTTP(w, r)
if w.Body.String() != "Get userlist" {
t.Errorf("TestRouterGet can't run")
}
}
func TestRouterPost(t *testing.T) {
r, _ := http.NewRequest("POST", "/user/123", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Post("/user/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
})
handler.ServeHTTP(w, r)
if w.Body.String() != "123" {
t.Errorf("TestRouterPost can't run")
}
}
func sayhello(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("sayhello"))
}
func TestRouterHandler(t *testing.T) {
r, _ := http.NewRequest("POST", "/sayhi", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Handler("/sayhi", http.HandlerFunc(sayhello))
handler.ServeHTTP(w, r)
if w.Body.String() != "sayhello" {
t.Errorf("TestRouterHandler can't run")
}
}
func TestRouterHandlerAll(t *testing.T) {
r, _ := http.NewRequest("POST", "/sayhi/a/b/c", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Handler("/sayhi", http.HandlerFunc(sayhello), true)
handler.ServeHTTP(w, r)
if w.Body.String() != "sayhello" {
t.Errorf("TestRouterHandler can't run")
}
}
//
// Benchmarks NewHttpSever:
//
func beegoFilterFunc(ctx *context.Context) {
ctx.WriteString("hello")
}
type AdminController struct {
Controller
}
func (a *AdminController) Get() {
a.Ctx.WriteString("hello")
}
func TestRouterFunc(t *testing.T) {
mux := NewControllerRegister()
mux.Get("/action", beegoFilterFunc)
mux.Post("/action", beegoFilterFunc)
rw, r := testRequest("GET", "/action")
mux.ServeHTTP(rw, r)
if rw.Body.String() != "hello" {
t.Errorf("TestRouterFunc can't run")
}
}
func BenchmarkFunc(b *testing.B) {
mux := NewControllerRegister()
mux.Get("/action", beegoFilterFunc)
rw, r := testRequest("GET", "/action")
b.ResetTimer()
for i := 0; i < b.N; i++ {
mux.ServeHTTP(rw, r)
}
}
func BenchmarkController(b *testing.B) {
mux := NewControllerRegister()
mux.Add("/action", &AdminController{})
rw, r := testRequest("GET", "/action")
b.ResetTimer()
for i := 0; i < b.N; i++ {
mux.ServeHTTP(rw, r)
}
}
func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request) {
request, _ := http.NewRequest(method, path, nil)
recorder := httptest.NewRecorder()
return recorder, request
}
// Expectation: A Filter with the correct configuration should be created given
// specific parameters.
func TestInsertFilter(t *testing.T) {
testName := "TestInsertFilter"
mux := NewControllerRegister()
mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true))
if !mux.filters[BeforeRouter][0].returnOnOutput {
t.Errorf(
"%s: passing no variadic params should set returnOnOutput to true",
testName)
}
if mux.filters[BeforeRouter][0].resetParams {
t.Errorf(
"%s: passing no variadic params should set resetParams to false",
testName)
}
mux = NewControllerRegister()
mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(false))
if mux.filters[BeforeRouter][0].returnOnOutput {
t.Errorf(
"%s: passing false as 1st variadic param should set returnOnOutput to false",
testName)
}
mux = NewControllerRegister()
mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true), WithResetParams(true))
if !mux.filters[BeforeRouter][0].resetParams {
t.Errorf(
"%s: passing true as 2nd variadic param should set resetParams to true",
testName)
}
}
// Expectation: the second variadic arg should cause the execution of the filter
// to preserve the parameters from before its execution.
func TestParamResetFilter(t *testing.T) {
testName := "TestParamResetFilter"
route := "/beego/*" // splat
path := "/beego/routes/routes"
mux := NewControllerRegister()
mux.InsertFilter("*", BeforeExec, beegoResetParams, WithReturnOnOutput(true), WithResetParams(true))
mux.Get(route, beegoHandleResetParams)
rw, r := testRequest("GET", path)
mux.ServeHTTP(rw, r)
// The two functions, `beegoResetParams` and `beegoHandleResetParams` add
// a response header of `Splat`. The expectation here is that that Header
// value should match what the _request's_ router set, not the filter's.
headers := rw.Result().Header
if len(headers["Splat"]) != 1 {
t.Errorf(
"%s: There was an error in the test. Splat param not set in Header",
testName)
}
if headers["Splat"][0] != "routes/routes" {
t.Errorf(
"%s: expected `:splat` param to be [routes/routes] but it was [%s]",
testName, headers["Splat"][0])
}
}
// Execution point: BeforeRouter
// expectation: only BeforeRouter function is executed, notmatch output as router doesn't handle
func TestFilterBeforeRouter(t *testing.T) {
testName := "TestFilterBeforeRouter"
url := "/beforeRouter"
mux := NewControllerRegister()
mux.InsertFilter(url, BeforeRouter, beegoBeforeRouter1)
mux.Get(url, beegoFilterFunc)
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
if !strings.Contains(rw.Body.String(), "BeforeRouter1") {
t.Errorf(testName + " BeforeRouter did not run")
}
if strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " BeforeRouter did not return properly")
}
}
// Execution point: BeforeExec
// expectation: only BeforeExec function is executed, match as router determines route only
func TestFilterBeforeExec(t *testing.T) {
testName := "TestFilterBeforeExec"
url := "/beforeExec"
mux := NewControllerRegister()
mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true))
mux.InsertFilter(url, BeforeExec, beegoBeforeExec1, WithReturnOnOutput(true))
mux.Get(url, beegoFilterFunc)
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
if !strings.Contains(rw.Body.String(), "BeforeExec1") {
t.Errorf(testName + " BeforeExec did not run")
}
if strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " BeforeExec did not return properly")
}
if strings.Contains(rw.Body.String(), "BeforeRouter") {
t.Errorf(testName + " BeforeRouter ran in error")
}
}
// Execution point: AfterExec
// expectation: only AfterExec function is executed, match as router handles
func TestFilterAfterExec(t *testing.T) {
testName := "TestFilterAfterExec"
url := "/afterExec"
mux := NewControllerRegister()
mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput)
mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput)
mux.InsertFilter(url, AfterExec, beegoAfterExec1, WithReturnOnOutput(false))
mux.Get(url, beegoFilterFunc)
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
if !strings.Contains(rw.Body.String(), "AfterExec1") {
t.Errorf(testName + " AfterExec did not run")
}
if !strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " handler did not run properly")
}
if strings.Contains(rw.Body.String(), "BeforeRouter") {
t.Errorf(testName + " BeforeRouter ran in error")
}
if strings.Contains(rw.Body.String(), "BeforeExec") {
t.Errorf(testName + " BeforeExec ran in error")
}
}
// Execution point: FinishRouter
// expectation: only FinishRouter function is executed, match as router handles
func TestFilterFinishRouter(t *testing.T) {
testName := "TestFilterFinishRouter"
url := "/finishRouter"
mux := NewControllerRegister()
mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true))
mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput, WithReturnOnOutput(true))
mux.InsertFilter(url, AfterExec, beegoFilterNoOutput, WithReturnOnOutput(true))
mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(true))
mux.Get(url, beegoFilterFunc)
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
if strings.Contains(rw.Body.String(), "FinishRouter1") {
t.Errorf(testName + " FinishRouter did not run")
}
if !strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " handler did not run properly")
}
if strings.Contains(rw.Body.String(), "AfterExec1") {
t.Errorf(testName + " AfterExec ran in error")
}
if strings.Contains(rw.Body.String(), "BeforeRouter") {
t.Errorf(testName + " BeforeRouter ran in error")
}
if strings.Contains(rw.Body.String(), "BeforeExec") {
t.Errorf(testName + " BeforeExec ran in error")
}
}
// Execution point: FinishRouter
// expectation: only first FinishRouter function is executed, match as router handles
func TestFilterFinishRouterMultiFirstOnly(t *testing.T) {
testName := "TestFilterFinishRouterMultiFirstOnly"
url := "/finishRouterMultiFirstOnly"
mux := NewControllerRegister()
mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false))
mux.InsertFilter(url, FinishRouter, beegoFinishRouter2)
mux.Get(url, beegoFilterFunc)
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
if !strings.Contains(rw.Body.String(), "FinishRouter1") {
t.Errorf(testName + " FinishRouter1 did not run")
}
if !strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " handler did not run properly")
}
// not expected in body
if strings.Contains(rw.Body.String(), "FinishRouter2") {
t.Errorf(testName + " FinishRouter2 did run")
}
}
// Execution point: FinishRouter
// expectation: both FinishRouter functions execute, match as router handles
func TestFilterFinishRouterMulti(t *testing.T) {
testName := "TestFilterFinishRouterMulti"
url := "/finishRouterMulti"
mux := NewControllerRegister()
mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false))
mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, WithReturnOnOutput(false))
mux.Get(url, beegoFilterFunc)
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
if !strings.Contains(rw.Body.String(), "FinishRouter1") {
t.Errorf(testName + " FinishRouter1 did not run")
}
if !strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " handler did not run properly")
}
if !strings.Contains(rw.Body.String(), "FinishRouter2") {
t.Errorf(testName + " FinishRouter2 did not run properly")
}
}
func beegoFilterNoOutput(ctx *context.Context) {
}
func beegoBeforeRouter1(ctx *context.Context) {
ctx.WriteString("|BeforeRouter1")
}
func beegoBeforeExec1(ctx *context.Context) {
ctx.WriteString("|BeforeExec1")
}
func beegoAfterExec1(ctx *context.Context) {
ctx.WriteString("|AfterExec1")
}
func beegoFinishRouter1(ctx *context.Context) {
ctx.WriteString("|FinishRouter1")
}
func beegoFinishRouter2(ctx *context.Context) {
ctx.WriteString("|FinishRouter2")
}
func beegoResetParams(ctx *context.Context) {
ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat"))
}
func beegoHandleResetParams(ctx *context.Context) {
ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat"))
}
// YAML
type YAMLController struct {
Controller
}
func (jc *YAMLController) Prepare() {
jc.Data["yaml"] = "prepare"
jc.ServeYAML()
}
func (jc *YAMLController) Get() {
jc.Data["Username"] = "astaxie"
jc.Ctx.Output.Body([]byte("ok"))
}
func TestYAMLPrepare(t *testing.T) {
r, _ := http.NewRequest("GET", "/yaml/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Add("/yaml/list", &YAMLController{})
handler.ServeHTTP(w, r)
if strings.TrimSpace(w.Body.String()) != "prepare" {
t.Errorf(w.Body.String())
}
}
func TestRouterEntityTooLargeCopyBody(t *testing.T) {
_MaxMemory := BConfig.MaxMemory
_CopyRequestBody := BConfig.CopyRequestBody
BConfig.CopyRequestBody = true
BConfig.MaxMemory = 20
BeeApp.Cfg.CopyRequestBody = true
BeeApp.Cfg.MaxMemory = 20
b := bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar"))
r, _ := http.NewRequest("POST", "/user/123", b)
w := httptest.NewRecorder()
handler := NewControllerRegister()
handler.Post("/user/:id", func(ctx *context.Context) {
ctx.Output.Body([]byte(ctx.Input.Param(":id")))
})
handler.ServeHTTP(w, r)
BConfig.CopyRequestBody = _CopyRequestBody
BConfig.MaxMemory = _MaxMemory
if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("TestRouterRequestEntityTooLarge can't run")
}
}

752
server/web/server.go Normal file
View File

@ -0,0 +1,752 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/fcgi"
"os"
"path"
"strconv"
"strings"
"text/template"
"time"
"golang.org/x/crypto/acme/autocert"
"github.com/astaxie/beego/core/logs"
beecontext "github.com/astaxie/beego/server/web/context"
"github.com/astaxie/beego/core/utils"
"github.com/astaxie/beego/server/web/grace"
)
var (
// BeeApp is an application instance
// If you are using single server, you could use this
// But if you need multiple servers, do not use this
BeeApp *HttpServer
)
func init() {
// create beego application
BeeApp = NewHttpSever()
}
// HttpServer defines beego application with a new PatternServeMux.
type HttpServer struct {
Handlers *ControllerRegister
Server *http.Server
Cfg *Config
}
// NewHttpSever returns a new beego application.
// this method will use the BConfig as the configure to create HttpServer
// Be careful that when you update BConfig, the server's Cfg will not be updated
func NewHttpSever() *HttpServer {
return NewHttpServerWithCfg(*BConfig)
}
// NewHttpServerWithCfg will create an sever with specific cfg
func NewHttpServerWithCfg(cfg Config) *HttpServer {
cfgPtr := &cfg
cr := NewControllerRegisterWithCfg(cfgPtr)
app := &HttpServer{
Handlers: cr,
Server: &http.Server{},
Cfg: cfgPtr,
}
return app
}
// MiddleWare function for http.Handler
type MiddleWare func(http.Handler) http.Handler
// Run beego application.
func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
initBeforeHTTPRun()
app.initAddr(addr)
addr = app.Cfg.Listen.HTTPAddr
if app.Cfg.Listen.HTTPPort != 0 {
addr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPAddr, app.Cfg.Listen.HTTPPort)
}
var (
err error
l net.Listener
endRunning = make(chan bool, 1)
)
// run cgi server
if app.Cfg.Listen.EnableFcgi {
if app.Cfg.Listen.EnableStdIo {
if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
logs.Info("Use FCGI via standard I/O")
} else {
logs.Critical("Cannot use FCGI via standard I/O", err)
}
return
}
if app.Cfg.Listen.HTTPPort == 0 {
// remove the Socket file before start
if utils.FileExists(addr) {
os.Remove(addr)
}
l, err = net.Listen("unix", addr)
} else {
l, err = net.Listen("tcp", addr)
}
if err != nil {
logs.Critical("Listen: ", err)
}
if err = fcgi.Serve(l, app.Handlers); err != nil {
logs.Critical("fcgi.Serve: ", err)
}
return
}
app.Server.Handler = app.Handlers
for i := len(mws) - 1; i >= 0; i-- {
if mws[i] == nil {
continue
}
app.Server.Handler = mws[i](app.Server.Handler)
}
app.Server.ReadTimeout = time.Duration(app.Cfg.Listen.ServerTimeOut) * time.Second
app.Server.WriteTimeout = time.Duration(app.Cfg.Listen.ServerTimeOut) * time.Second
app.Server.ErrorLog = logs.GetLogger("HTTP")
// run graceful mode
if app.Cfg.Listen.Graceful {
httpsAddr := app.Cfg.Listen.HTTPSAddr
app.Server.Addr = httpsAddr
if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS {
go func() {
time.Sleep(1000 * time.Microsecond)
if app.Cfg.Listen.HTTPSPort != 0 {
httpsAddr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort)
app.Server.Addr = httpsAddr
}
server := grace.NewServer(httpsAddr, app.Server.Handler)
server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout
if app.Cfg.Listen.EnableMutualHTTPS {
if err := server.ListenAndServeMutualTLS(app.Cfg.Listen.HTTPSCertFile,
app.Cfg.Listen.HTTPSKeyFile,
app.Cfg.Listen.TrustCaFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
}
} else {
if app.Cfg.Listen.AutoTLS {
m := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(app.Cfg.Listen.Domains...),
Cache: autocert.DirCache(app.Cfg.Listen.TLSCacheDir),
}
app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate}
app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", ""
}
if err := server.ListenAndServeTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
}
}
endRunning <- true
}()
}
if app.Cfg.Listen.EnableHTTP {
go func() {
server := grace.NewServer(addr, app.Server.Handler)
server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout
if app.Cfg.Listen.ListenTCP4 {
server.Network = "tcp4"
}
if err := server.ListenAndServe(); err != nil {
logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
}
endRunning <- true
}()
}
<-endRunning
return
}
// run normal mode
if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS {
go func() {
time.Sleep(1000 * time.Microsecond)
if app.Cfg.Listen.HTTPSPort != 0 {
app.Server.Addr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort)
} else if app.Cfg.Listen.EnableHTTP {
logs.Info("Start https server error, conflict with http. Please reset https port")
return
}
logs.Info("https server Running on https://%s", app.Server.Addr)
if app.Cfg.Listen.AutoTLS {
m := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(app.Cfg.Listen.Domains...),
Cache: autocert.DirCache(app.Cfg.Listen.TLSCacheDir),
}
app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate}
app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", ""
} else if app.Cfg.Listen.EnableMutualHTTPS {
pool := x509.NewCertPool()
data, err := ioutil.ReadFile(app.Cfg.Listen.TrustCaFile)
if err != nil {
logs.Info("MutualHTTPS should provide TrustCaFile")
return
}
pool.AppendCertsFromPEM(data)
app.Server.TLSConfig = &tls.Config{
ClientCAs: pool,
ClientAuth: tls.ClientAuthType(app.Cfg.Listen.ClientAuth),
}
}
if err := app.Server.ListenAndServeTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}()
}
if app.Cfg.Listen.EnableHTTP {
go func() {
app.Server.Addr = addr
logs.Info("http server Running on http://%s", app.Server.Addr)
if app.Cfg.Listen.ListenTCP4 {
ln, err := net.Listen("tcp4", app.Server.Addr)
if err != nil {
logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return
}
if err = app.Server.Serve(ln); err != nil {
logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return
}
} else {
if err := app.Server.ListenAndServe(); err != nil {
logs.Critical("ListenAndServe: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
}
}
}()
}
<-endRunning
}
// Router see HttpServer.Router
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
return BeeApp.Router(rootpath, c, mappingMethods...)
}
// Router adds a patterned controller handler to BeeApp.
// it's an alias method of HttpServer.Router.
// usage:
// simple router
// beego.Router("/admin", &admin.UserController{})
// beego.Router("/admin/index", &admin.ArticleController{})
//
// regex router
//
// beego.Router("/api/:id([0-9]+)", &controllers.RController{})
//
// custom rules
// beego.Router("/api/list",&RestController{},"*:ListFood")
// beego.Router("/api/create",&RestController{},"post:CreateFood")
// beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
app.Handlers.Add(rootPath, c, mappingMethods...)
return app
}
// UnregisterFixedRoute see HttpServer.UnregisterFixedRoute
func UnregisterFixedRoute(fixedRoute string, method string) *HttpServer {
return BeeApp.UnregisterFixedRoute(fixedRoute, method)
}
// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful
// in web applications that inherit most routes from a base webapp via the underscore
// import, and aim to overwrite only certain paths.
// The method parameter can be empty or "*" for all HTTP methods, or a particular
// method type (e.g. "GET" or "POST") for selective removal.
//
// Usage (replace "GET" with "*" for all methods):
// beego.UnregisterFixedRoute("/yourpreviouspath", "GET")
// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage")
func (app *HttpServer) UnregisterFixedRoute(fixedRoute string, method string) *HttpServer {
subPaths := splitPath(fixedRoute)
if method == "" || method == "*" {
for m := range HTTPMETHOD {
if _, ok := app.Handlers.routers[m]; !ok {
continue
}
if app.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") {
findAndRemoveSingleTree(app.Handlers.routers[m])
continue
}
findAndRemoveTree(subPaths, app.Handlers.routers[m], m)
}
return app
}
// Single HTTP method
um := strings.ToUpper(method)
if _, ok := app.Handlers.routers[um]; ok {
if app.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") {
findAndRemoveSingleTree(app.Handlers.routers[um])
return app
}
findAndRemoveTree(subPaths, app.Handlers.routers[um], um)
}
return app
}
func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) {
for i := range entryPointTree.fixrouters {
if entryPointTree.fixrouters[i].prefix == paths[0] {
if len(paths) == 1 {
if len(entryPointTree.fixrouters[i].fixrouters) > 0 {
// If the route had children subtrees, remove just the functional leaf,
// to allow children to function as before
if len(entryPointTree.fixrouters[i].leaves) > 0 {
entryPointTree.fixrouters[i].leaves[0] = nil
entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:]
}
} else {
// Remove the *Tree from the fixrouters slice
entryPointTree.fixrouters[i] = nil
if i == len(entryPointTree.fixrouters)-1 {
entryPointTree.fixrouters = entryPointTree.fixrouters[:i]
} else {
entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...)
}
}
return
}
findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method)
}
}
}
func findAndRemoveSingleTree(entryPointTree *Tree) {
if entryPointTree == nil {
return
}
if len(entryPointTree.fixrouters) > 0 {
// If the route had children subtrees, remove just the functional leaf,
// to allow children to function as before
if len(entryPointTree.leaves) > 0 {
entryPointTree.leaves[0] = nil
entryPointTree.leaves = entryPointTree.leaves[1:]
}
}
}
// Include see HttpServer.Include
func Include(cList ...ControllerInterface) *HttpServer {
return BeeApp.Include(cList...)
}
// Include will generate router file in the router/xxx.go from the controller's comments
// usage:
// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
// type BankAccount struct{
// beego.Controller
// }
//
// register the function
// func (b *BankAccount)Mapping(){
// b.Mapping("ShowAccount" , b.ShowAccount)
// b.Mapping("ModifyAccount", b.ModifyAccount)
// }
//
// //@router /account/:id [get]
// func (b *BankAccount) ShowAccount(){
// //logic
// }
//
//
// //@router /account/:id [post]
// func (b *BankAccount) ModifyAccount(){
// //logic
// }
//
// the comments @router url methodlist
// url support all the function Router's pattern
// methodlist [get post head put delete options *]
func (app *HttpServer) Include(cList ...ControllerInterface) *HttpServer {
app.Handlers.Include(cList...)
return app
}
// RESTRouter see HttpServer.RESTRouter
func RESTRouter(rootpath string, c ControllerInterface) *HttpServer {
return BeeApp.RESTRouter(rootpath, c)
}
// RESTRouter adds a restful controller handler to BeeApp.
// its' controller implements beego.ControllerInterface and
// defines a param "pattern/:objectId" to visit each resource.
func (app *HttpServer) RESTRouter(rootpath string, c ControllerInterface) *HttpServer {
app.Router(rootpath, c)
app.Router(path.Join(rootpath, ":objectId"), c)
return app
}
// AutoRouter see HttpServer.AutoRouter
func AutoRouter(c ControllerInterface) *HttpServer {
return BeeApp.AutoRouter(c)
}
// AutoRouter adds defined controller handler to BeeApp.
// it's same to HttpServer.AutoRouter.
// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page,
// visit the url /main/list to exec List function or /main/page to exec Page function.
func (app *HttpServer) AutoRouter(c ControllerInterface) *HttpServer {
app.Handlers.AddAuto(c)
return app
}
// AutoPrefix see HttpServer.AutoPrefix
func AutoPrefix(prefix string, c ControllerInterface) *HttpServer {
return BeeApp.AutoPrefix(prefix, c)
}
// AutoPrefix adds controller handler to BeeApp with prefix.
// it's same to HttpServer.AutoRouterWithPrefix.
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
func (app *HttpServer) AutoPrefix(prefix string, c ControllerInterface) *HttpServer {
app.Handlers.AddAutoPrefix(prefix, c)
return app
}
// Get see HttpServer.Get
func Get(rootpath string, f FilterFunc) *HttpServer {
return BeeApp.Get(rootpath, f)
}
// Get used to register router for Get method
// usage:
// beego.Get("/", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (app *HttpServer) Get(rootpath string, f FilterFunc) *HttpServer {
app.Handlers.Get(rootpath, f)
return app
}
// Post see HttpServer.Post
func Post(rootpath string, f FilterFunc) *HttpServer {
return BeeApp.Post(rootpath, f)
}
// Post used to register router for Post method
// usage:
// beego.Post("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (app *HttpServer) Post(rootpath string, f FilterFunc) *HttpServer {
app.Handlers.Post(rootpath, f)
return app
}
// Delete see HttpServer.Delete
func Delete(rootpath string, f FilterFunc) *HttpServer {
return BeeApp.Delete(rootpath, f)
}
// Delete used to register router for Delete method
// usage:
// beego.Delete("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (app *HttpServer) Delete(rootpath string, f FilterFunc) *HttpServer {
app.Handlers.Delete(rootpath, f)
return app
}
// Put see HttpServer.Put
func Put(rootpath string, f FilterFunc) *HttpServer {
return BeeApp.Put(rootpath, f)
}
// Put used to register router for Put method
// usage:
// beego.Put("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (app *HttpServer) Put(rootpath string, f FilterFunc) *HttpServer {
app.Handlers.Put(rootpath, f)
return app
}
// Head see HttpServer.Head
func Head(rootpath string, f FilterFunc) *HttpServer {
return BeeApp.Head(rootpath, f)
}
// Head used to register router for Head method
// usage:
// beego.Head("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (app *HttpServer) Head(rootpath string, f FilterFunc) *HttpServer {
app.Handlers.Head(rootpath, f)
return app
}
// Options see HttpServer.Options
func Options(rootpath string, f FilterFunc) *HttpServer {
BeeApp.Handlers.Options(rootpath, f)
return BeeApp
}
// Options used to register router for Options method
// usage:
// beego.Options("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (app *HttpServer) Options(rootpath string, f FilterFunc) *HttpServer {
app.Handlers.Options(rootpath, f)
return app
}
// Patch see HttpServer.Patch
func Patch(rootpath string, f FilterFunc) *HttpServer {
return BeeApp.Patch(rootpath, f)
}
// Patch used to register router for Patch method
// usage:
// beego.Patch("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (app *HttpServer) Patch(rootpath string, f FilterFunc) *HttpServer {
app.Handlers.Patch(rootpath, f)
return app
}
// Any see HttpServer.Any
func Any(rootpath string, f FilterFunc) *HttpServer {
return BeeApp.Any(rootpath, f)
}
// Any used to register router for all methods
// usage:
// beego.Any("/api", func(ctx *context.Context){
// ctx.Output.Body("hello world")
// })
func (app *HttpServer) Any(rootpath string, f FilterFunc) *HttpServer {
app.Handlers.Any(rootpath, f)
return app
}
// Handler see HttpServer.Handler
func Handler(rootpath string, h http.Handler, options ...interface{}) *HttpServer {
return BeeApp.Handler(rootpath, h, options...)
}
// Handler used to register a Handler router
// usage:
// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path))
// }))
func (app *HttpServer) Handler(rootpath string, h http.Handler, options ...interface{}) *HttpServer {
app.Handlers.Handler(rootpath, h, options...)
return app
}
// InserFilter see HttpServer.InsertFilter
func InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *HttpServer {
return BeeApp.InsertFilter(pattern, pos, filter, opts...)
}
// InsertFilter adds a FilterFunc with pattern condition and action constant.
// The pos means action constant including
// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter.
// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
func (app *HttpServer) InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *HttpServer {
app.Handlers.InsertFilter(pattern, pos, filter, opts...)
return app
}
// InsertFilterChain see HttpServer.InsertFilterChain
func InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *HttpServer {
return BeeApp.InsertFilterChain(pattern, filterChain, opts...)
}
// InsertFilterChain adds a FilterFunc built by filterChain.
// This filter will be executed before all filters.
// the filter's behavior like stack's behavior
// and the last filter is serving the http request
func (app *HttpServer) InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *HttpServer {
app.Handlers.InsertFilterChain(pattern, filterChain, opts...)
return app
}
func (app *HttpServer) initAddr(addr string) {
strs := strings.Split(addr, ":")
if len(strs) > 0 && strs[0] != "" {
app.Cfg.Listen.HTTPAddr = strs[0]
app.Cfg.Listen.Domains = []string{strs[0]}
}
if len(strs) > 1 && strs[1] != "" {
app.Cfg.Listen.HTTPPort, _ = strconv.Atoi(strs[1])
}
}
func (app *HttpServer) LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) {
// Skip logging if AccessLogs config is false
if !app.Cfg.Log.AccessLogs {
return
}
// Skip logging static requests unless EnableStaticLogs config is true
if !app.Cfg.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) {
return
}
var (
requestTime time.Time
elapsedTime time.Duration
r = ctx.Request
)
if startTime != nil {
requestTime = *startTime
elapsedTime = time.Since(*startTime)
}
record := &logs.AccessLogRecord{
RemoteAddr: ctx.Input.IP(),
RequestTime: requestTime,
RequestMethod: r.Method,
Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto),
ServerProtocol: r.Proto,
Host: r.Host,
Status: statusCode,
ElapsedTime: elapsedTime,
HTTPReferrer: r.Header.Get("Referer"),
HTTPUserAgent: r.Header.Get("User-Agent"),
RemoteUser: r.Header.Get("Remote-User"),
BodyBytesSent: r.ContentLength,
}
logs.AccessLog(record, app.Cfg.Log.AccessLogsFormat)
}
// PrintTree prints all registered routers.
func (app *HttpServer) PrintTree() M {
var (
content = M{}
methods = []string{}
methodsData = make(M)
)
for method, t := range app.Handlers.routers {
resultList := new([][]string)
printTree(resultList, t)
methods = append(methods, template.HTMLEscapeString(method))
methodsData[template.HTMLEscapeString(method)] = resultList
}
content["Data"] = methodsData
content["Methods"] = methods
return content
}
func printTree(resultList *[][]string, t *Tree) {
for _, tr := range t.fixrouters {
printTree(resultList, tr)
}
if t.wildcard != nil {
printTree(resultList, t.wildcard)
}
for _, l := range t.leaves {
if v, ok := l.runObject.(*ControllerInfo); ok {
if v.routerType == routerTypeBeego {
var result = []string{
template.HTMLEscapeString(v.pattern),
template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)),
template.HTMLEscapeString(v.controllerType.String()),
}
*resultList = append(*resultList, result)
} else if v.routerType == routerTypeRESTFul {
var result = []string{
template.HTMLEscapeString(v.pattern),
template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)),
"",
}
*resultList = append(*resultList, result)
} else if v.routerType == routerTypeHandler {
var result = []string{
template.HTMLEscapeString(v.pattern),
"",
"",
}
*resultList = append(*resultList, result)
}
}
}
}
func (app *HttpServer) reportFilter() M {
filterTypeData := make(M)
// filterTypes := []string{}
if app.Handlers.enableFilter {
// var filterType string
for k, fr := range map[int]string{
BeforeStatic: "Before Static",
BeforeRouter: "Before Router",
BeforeExec: "Before Exec",
AfterExec: "After Exec",
FinishRouter: "Finish Router",
} {
if bf := app.Handlers.filters[k]; len(bf) > 0 {
resultList := new([][]string)
for _, f := range bf {
var result = []string{
// void xss
template.HTMLEscapeString(f.pattern),
template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)),
}
*resultList = append(*resultList, result)
}
filterTypeData[fr] = resultList
}
}
}
return filterTypeData
}

31
server/web/server_test.go Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewHttpServerWithCfg(t *testing.T) {
// we should make sure that update server's config won't change
BConfig.AppName = "Before"
svr := NewHttpServerWithCfg(*BConfig)
svr.Cfg.AppName = "hello"
assert.NotEqual(t, "hello", BConfig.AppName)
assert.Equal(t, "Before", BConfig.AppName)
}

View File

@ -0,0 +1,114 @@
session
==============
session is a Go session manager. It can use many session providers. Just like the `database/sql` and `database/sql/driver`.
## How to install?
go get github.com/astaxie/beego/session
## What providers are supported?
As of now this session manager support memory, file, Redis and MySQL.
## How to use it?
First you must import it
import (
"github.com/astaxie/beego/session"
)
Then in you web app init the global session manager
var globalSessions *session.Manager
* Use **memory** as provider:
func init() {
globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
go globalSessions.GC()
}
* Use **file** as provider, the last param is the path where you want file to be stored:
func init() {
globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`)
go globalSessions.GC()
}
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
func init() {
globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`)
go globalSessions.GC()
}
* Use **MySQL** as provider, the last param is the DSN, learn more from [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name):
func init() {
globalSessions, _ = session.NewManager(
"mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`)
go globalSessions.GC()
}
* Use **Cookie** as provider:
func init() {
globalSessions, _ = session.NewManager(
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
go globalSessions.GC()
}
Finally in the handlerfunc you can use it like this
func login(w http.ResponseWriter, r *http.Request) {
sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease(w)
username := sess.Get("username")
fmt.Println(username)
if r.Method == "GET" {
t, _ := template.ParseFiles("login.gtpl")
t.Execute(w, nil)
} else {
fmt.Println("username:", r.Form["username"])
sess.Set("username", r.Form["username"])
fmt.Println("password:", r.Form["password"])
}
}
## How to write own provider?
When you develop a web app, maybe you want to write own provider because you must meet the requirements.
Writing a provider is easy. You only need to define two struct types
(Session and Provider), which satisfy the interface definition.
Maybe you will find the **memory** provider is a good example.
type SessionStore interface {
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data
}
type Provider interface {
SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) (bool, error)
SessionRegenerate(oldsid, sid string) (SessionStore, error)
SessionDestroy(sid string) error
SessionAll() int //get all active session
SessionGC()
}
## LICENSE
BSD License http://creativecommons.org/licenses/BSD/

View File

@ -0,0 +1,248 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package couchbase for session provider
//
// depend on github.com/couchbaselabs/go-couchbasee
//
// go install github.com/couchbaselabs/go-couchbase
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/couchbase"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package couchbase
import (
"context"
"net/http"
"strings"
"sync"
couchbase "github.com/couchbase/go-couchbase"
"github.com/astaxie/beego/server/web/session"
)
var couchbpder = &Provider{}
// SessionStore store each session
type SessionStore struct {
b *couchbase.Bucket
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Provider couchabse provided
type Provider struct {
maxlifetime int64
savePath string
pool string
bucket string
b *couchbase.Bucket
}
// Set value to couchabse session
func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values[key] = value
return nil
}
// Get value from couchabse session
func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
cs.lock.RLock()
defer cs.lock.RUnlock()
if v, ok := cs.values[key]; ok {
return v
}
return nil
}
// Delete value in couchbase session by given key
func (cs *SessionStore) Delete(ctx context.Context, key interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
delete(cs.values, key)
return nil
}
// Flush Clean all values in couchbase session
func (cs *SessionStore) Flush(context.Context) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values = make(map[interface{}]interface{})
return nil
}
// SessionID Get couchbase session store id
func (cs *SessionStore) SessionID(context.Context) string {
return cs.sid
}
// SessionRelease Write couchbase session with Gob string
func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer cs.b.Close()
bo, err := session.EncodeGob(cs.values)
if err != nil {
return
}
cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
}
func (cp *Provider) getBucket() *couchbase.Bucket {
c, err := couchbase.Connect(cp.savePath)
if err != nil {
return nil
}
pool, err := c.GetPool(cp.pool)
if err != nil {
return nil
}
bucket, err := pool.GetBucket(cp.bucket)
if err != nil {
return nil
}
return bucket
}
// SessionInit init couchbase session
// savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket
func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
cp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
cp.savePath = configs[0]
}
if len(configs) > 1 {
cp.pool = configs[1]
}
if len(configs) > 2 {
cp.bucket = configs[2]
}
return nil
}
// SessionRead read couchbase session by sid
func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
cp.b = cp.getBucket()
var (
kv map[interface{}]interface{}
err error
doc []byte
)
err = cp.b.Get(sid, &doc)
if err != nil {
return nil, err
} else if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(doc)
if err != nil {
return nil, err
}
}
cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
// SessionExist Check couchbase session exist.
// it checkes sid exist or not.
func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
cp.b = cp.getBucket()
defer cp.b.Close()
var doc []byte
if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
return false, err
}
return true, nil
}
// SessionRegenerate remove oldsid and use sid to generate new session
func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
cp.b = cp.getBucket()
var doc []byte
if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil {
cp.b.Set(sid, int(cp.maxlifetime), "")
} else {
err := cp.b.Delete(oldsid)
if err != nil {
return nil, err
}
_, _ = cp.b.Add(sid, int(cp.maxlifetime), doc)
}
err := cp.b.Get(sid, &doc)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(doc)
if err != nil {
return nil, err
}
}
cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
// SessionDestroy Remove bucket in this couchbase
func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error {
cp.b = cp.getBucket()
defer cp.b.Close()
cp.b.Delete(sid)
return nil
}
// SessionGC Recycle
func (cp *Provider) SessionGC(context.Context) {
}
// SessionAll return all active session
func (cp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("couchbase", couchbpder)
}

View File

@ -0,0 +1,174 @@
// Package ledis provide session Provider
package ledis
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"github.com/ledisdb/ledisdb/config"
"github.com/ledisdb/ledisdb/ledis"
"github.com/astaxie/beego/server/web/session"
)
var (
ledispder = &Provider{}
c *ledis.DB
)
// SessionStore ledis session store
type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in ledis session
func (ls *SessionStore) Set(ctx context.Context, key, value interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values[key] = value
return nil
}
// Get value in ledis session
func (ls *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
ls.lock.RLock()
defer ls.lock.RUnlock()
if v, ok := ls.values[key]; ok {
return v
}
return nil
}
// Delete value in ledis session
func (ls *SessionStore) Delete(ctx context.Context, key interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
delete(ls.values, key)
return nil
}
// Flush clear all values in ledis session
func (ls *SessionStore) Flush(context.Context) error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values = make(map[interface{}]interface{})
return nil
}
// SessionID get ledis session id
func (ls *SessionStore) SessionID(context.Context) string {
return ls.sid
}
// SessionRelease save session values to ledis
func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(ls.values)
if err != nil {
return
}
c.Set([]byte(ls.sid), b)
c.Expire([]byte(ls.sid), ls.maxlifetime)
}
// Provider ledis session provider
type Provider struct {
maxlifetime int64
savePath string
db int
}
// SessionInit init ledis session
// savepath like ledis server saveDataPath,pool size
// e.g. 127.0.0.1:6379,100,astaxie
func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
var err error
lp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) == 1 {
lp.savePath = configs[0]
} else if len(configs) == 2 {
lp.savePath = configs[0]
lp.db, err = strconv.Atoi(configs[1])
if err != nil {
return err
}
}
cfg := new(config.Config)
cfg.DataDir = lp.savePath
var ledisInstance *ledis.Ledis
ledisInstance, err = ledis.Open(cfg)
if err != nil {
return err
}
c, err = ledisInstance.Select(lp.db)
return err
}
// SessionRead read ledis session by sid
func (lp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var (
kv map[interface{}]interface{}
err error
)
kvs, _ := c.Get([]byte(sid))
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob(kvs); err != nil {
return nil, err
}
}
ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
return ls, nil
}
// SessionExist check ledis session exist by sid
func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
count, _ := c.Exists([]byte(sid))
return count != 0, nil
}
// SessionRegenerate generate new sid for ledis session
func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
count, _ := c.Exists([]byte(sid))
if count == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set([]byte(sid), []byte(""))
c.Expire([]byte(sid), lp.maxlifetime)
} else {
data, _ := c.Get([]byte(oldsid))
c.Set([]byte(sid), data)
c.Expire([]byte(sid), lp.maxlifetime)
}
return lp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete ledis session by id
func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c.Del([]byte(sid))
return nil
}
// SessionGC Impelment method, no used.
func (lp *Provider) SessionGC(context.Context) {
}
// SessionAll return all active session
func (lp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("ledis", ledispder)
}

View File

@ -0,0 +1,231 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package memcache for session provider
//
// depend on github.com/bradfitz/gomemcache/memcache
//
// go install github.com/bradfitz/gomemcache/memcache
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/memcache"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package memcache
import (
"context"
"net/http"
"strings"
"sync"
"github.com/astaxie/beego/server/web/session"
"github.com/bradfitz/gomemcache/memcache"
)
var mempder = &MemProvider{}
var client *memcache.Client
// SessionStore memcache session store
type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in memcache session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in memcache session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in memcache session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in memcache session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get memcache session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to memcache
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)}
client.Set(&item)
}
// MemProvider memcache session provider
type MemProvider struct {
maxlifetime int64
conninfo []string
poolsize int
password string
}
// SessionInit init memcache session
// savepath like
// e.g. 127.0.0.1:9090
func (rp *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
rp.conninfo = strings.Split(savePath, ";")
client = memcache.New(rp.conninfo...)
return nil
}
// SessionRead read memcache session by sid
func (rp *MemProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
}
}
item, err := client.Get(sid)
if err != nil {
if err == memcache.ErrCacheMiss {
rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime}
return rs, nil
}
return nil, err
}
var kv map[interface{}]interface{}
if len(item.Value) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(item.Value)
if err != nil {
return nil, err
}
}
rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check memcache session exist by sid
func (rp *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return false, err
}
}
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for memcache session
func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
}
}
var contain []byte
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
item.Key = sid
item.Value = []byte("")
item.Expiration = int32(rp.maxlifetime)
client.Set(item)
} else {
client.Delete(oldsid)
item.Key = sid
item.Expiration = int32(rp.maxlifetime)
client.Set(item)
contain = item.Value
}
var kv map[interface{}]interface{}
if len(contain) == 0 {
kv = make(map[interface{}]interface{})
} else {
var err error
kv, err = session.DecodeGob(contain)
if err != nil {
return nil, err
}
}
rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionDestroy delete memcache session by id
func (rp *MemProvider) SessionDestroy(ctx context.Context, sid string) error {
if client == nil {
if err := rp.connectInit(); err != nil {
return err
}
}
return client.Delete(sid)
}
func (rp *MemProvider) connectInit() error {
client = memcache.New(rp.conninfo...)
return nil
}
// SessionGC Impelment method, no used.
func (rp *MemProvider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *MemProvider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("memcache", mempder)
}

View File

@ -0,0 +1,235 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package mysql for session provider
//
// depends on github.com/go-sql-driver/mysql:
//
// go install github.com/go-sql-driver/mysql
//
// mysql session support need create table as sql:
// CREATE TABLE `session` (
// `session_key` char(64) NOT NULL,
// `session_data` blob,
// `session_expiry` int(11) unsigned NOT NULL,
// PRIMARY KEY (`session_key`)
// ) ENGINE=MyISAM DEFAULT CHARSET=utf8;
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/mysql"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package mysql
import (
"context"
"database/sql"
"net/http"
"sync"
"time"
"github.com/astaxie/beego/server/web/session"
// import mysql driver
_ "github.com/go-sql-driver/mysql"
)
var (
// TableName store the session in MySQL
TableName = "session"
mysqlpder = &Provider{}
)
// SessionStore mysql session store
type SessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// Set value in mysql session.
// it is temp value in map.
func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// Get value from mysql session
func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
}
return nil
}
// Delete value in mysql session
func (st *SessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// Flush clear all values in mysql session
func (st *SessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// SessionID get session id of this mysql session store
func (st *SessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease save mysql session values to database.
// must call this method to save values to database.
func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid)
}
// Provider mysql session provider
type Provider struct {
maxlifetime int64
savePath string
}
// connect to mysql
func (mp *Provider) connectInit() *sql.DB {
db, e := sql.Open("mysql", mp.savePath)
if e != nil {
return nil
}
return db
}
// SessionInit init mysql session.
// savepath is the connection string of mysql.
func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
// SessionRead get mysql session by sid
func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
sid, "", time.Now().Unix())
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionExist check mysql session exist
func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for mysql session
func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
}
c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionDestroy delete mysql session by sid
func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM "+TableName+" where session_key=?", sid)
c.Close()
return nil
}
// SessionGC delete expired values in mysql session
func (mp *Provider) SessionGC(context.Context) {
c := mp.connectInit()
c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
c.Close()
}
// SessionAll count values in mysql session
func (mp *Provider) SessionAll(context.Context) int {
c := mp.connectInit()
defer c.Close()
var total int
err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total)
if err != nil {
return 0
}
return total
}
func init() {
session.Register("mysql", mysqlpder)
}

View File

@ -0,0 +1,250 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package postgres for session provider
//
// depends on github.com/lib/pq:
//
// go install github.com/lib/pq
//
//
// needs this table in your database:
//
// CREATE TABLE session (
// session_key char(64) NOT NULL,
// session_data bytea,
// session_expiry timestamp NOT NULL,
// CONSTRAINT session_key PRIMARY KEY(session_key)
// );
//
// will be activated with these settings in app.conf:
//
// SessionOn = true
// SessionProvider = postgresql
// SessionSavePath = "user=a password=b dbname=c sslmode=disable"
// SessionName = session
//
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/postgresql"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package postgres
import (
"context"
"database/sql"
"net/http"
"sync"
"time"
"github.com/astaxie/beego/server/web/session"
// import postgresql Driver
_ "github.com/lib/pq"
)
var postgresqlpder = &Provider{}
// SessionStore postgresql session store
type SessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// Set value in postgresql session.
// it is temp value in map.
func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// Get value from postgresql session
func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
}
return nil
}
// Delete value in postgresql session
func (st *SessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// Flush clear all values in postgresql session
func (st *SessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// SessionID get session id of this postgresql session store
func (st *SessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease save postgresql session values to database.
// must call this method to save values to database.
func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3",
b, time.Now().Format(time.RFC3339), st.sid)
}
// Provider postgresql session provider
type Provider struct {
maxlifetime int64
savePath string
}
// connect to postgresql
func (mp *Provider) connectInit() *sql.DB {
db, e := sql.Open("postgres", mp.savePath)
if e != nil {
return nil
}
return db
}
// SessionInit init postgresql session.
// savepath is the connection string of postgresql.
func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
// SessionRead get postgresql session by sid
func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
_, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
sid, "", time.Now().Format(time.RFC3339))
if err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionExist check postgresql session exist
func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for postgresql session
func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
oldsid, "", time.Now().Format(time.RFC3339))
}
c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionDestroy delete postgresql session by sid
func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=$1", sid)
c.Close()
return nil
}
// SessionGC delete expired values in postgresql session
func (mp *Provider) SessionGC(context.Context) {
c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close()
}
// SessionAll count values in postgresql session
func (mp *Provider) SessionAll(context.Context) int {
c := mp.connectInit()
defer c.Close()
var total int
err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
if err != nil {
return 0
}
return total
}
func init() {
session.Register("postgresql", postgresqlpder)
}

View File

@ -0,0 +1,252 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/gomodule/redigo/redis
//
// go install github.com/gomodule/redigo/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package redis
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/go-redis/redis/v7"
"github.com/astaxie/beego/server/web/session"
)
var redispder = &Provider{}
// MaxPoolSize redis max pool size
var MaxPoolSize = 100
// SessionStore redis session store
type SessionStore struct {
p *redis.Client
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *redis.Client
}
// SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
// e.g. 127.0.0.1:6379,100,astaxie,0,30
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4])
if err == nil && timeout > 0 {
rp.idleTimeout = time.Duration(timeout) * time.Second
}
}
if len(configs) > 5 {
checkFrequency, err := strconv.Atoi(configs[5])
if err == nil && checkFrequency > 0 {
rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second
}
}
if len(configs) > 6 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
}
}
rp.poollist = redis.NewClient(&redis.Options{
Addr: rp.savePath,
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis session by sid
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != redis.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis session exist by sid
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for redis session
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, _ := c.Exists(oldsid).Result(); existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Do(c.Context(), "SET", sid, "", "EX", rp.maxlifetime)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("redis", redispder)
}

View File

@ -0,0 +1,96 @@
package redis
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/astaxie/beego/server/web/session"
)
func TestRedis(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
}
redisAddr := os.Getenv("REDIS_ADDR")
if redisAddr == "" {
redisAddr = "127.0.0.1:6379"
}
sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr)
globalSession, err := session.NewManager("redis", sessionConfig)
if err != nil {
t.Fatal("could not create manager:", err)
}
go globalSession.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSession.SessionStart(w, r)
if err != nil {
t.Fatal("session start failed:", err)
}
defer sess.SessionRelease(nil, w)
// SET AND GET
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set username failed:", err)
}
username := sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
// DELETE
err = sess.Delete(nil, "username")
if err != nil {
t.Fatal("delete username failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("delete username failed")
}
// FLUSH
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set failed:", err)
}
err = sess.Set(nil, "password", "1qaz2wsx")
if err != nil {
t.Fatal("set failed:", err)
}
username = sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
password := sess.Get(nil, "password")
if password != "1qaz2wsx" {
t.Fatal("get password failed")
}
err = sess.Flush(nil)
if err != nil {
t.Fatal("flush failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("flush failed")
}
password = sess.Get(nil, "password")
if password != nil {
t.Fatal("flush failed")
}
sess.SessionRelease(nil, w)
}

View File

@ -0,0 +1,247 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/go-redis/redis
//
// go install github.com/go-redis/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis_cluster"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package redis_cluster
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/server/web/session"
rediss "github.com/go-redis/redis/v7"
)
var redispder = &Provider{}
// MaxPoolSize redis_cluster max pool size
var MaxPoolSize = 1000
// SessionStore redis_cluster session store
type SessionStore struct {
p *rediss.ClusterClient
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis_cluster session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis_cluster session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis_cluster session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis_cluster session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis_cluster session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis_cluster
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis_cluster session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *rediss.ClusterClient
}
// SessionInit init redis_cluster session
// savepath like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4])
if err == nil && timeout > 0 {
rp.idleTimeout = time.Duration(timeout) * time.Second
}
}
if len(configs) > 5 {
checkFrequency, err := strconv.Atoi(configs[5])
if err == nil && checkFrequency > 0 {
rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second
}
}
if len(configs) > 6 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
}
}
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
Addrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_cluster session by sid
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != rediss.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis_cluster session exist by sid
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for redis_cluster session
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("redis_cluster", redispder)
}

View File

@ -0,0 +1,261 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/go-redis/redis
//
// go install github.com/go-redis/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis_sentinel"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``)
// go globalSessions.GC()
// }
//
// more detail about params: please check the notes on the function SessionInit in this package
package redis_sentinel
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/go-redis/redis/v7"
"github.com/astaxie/beego/server/web/session"
)
var redispder = &Provider{}
// DefaultPoolSize redis_sentinel default pool size
var DefaultPoolSize = 100
// SessionStore redis_sentinel session store
type SessionStore struct {
p *redis.Client
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis_sentinel session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis_sentinel session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis_sentinel session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis_sentinel session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis_sentinel session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis_sentinel
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis_sentinel session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *redis.Client
masterName string
}
// SessionInit init redis_sentinel session
// savepath like redis sentinel addr,pool size,password,dbnum,masterName
// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = DefaultPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = DefaultPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
if configs[4] != "" {
rp.masterName = configs[4]
} else {
rp.masterName = "mymaster"
}
} else {
rp.masterName = "mymaster"
}
if len(configs) > 5 {
timeout, err := strconv.Atoi(configs[4])
if err == nil && timeout > 0 {
rp.idleTimeout = time.Duration(timeout) * time.Second
}
}
if len(configs) > 6 {
checkFrequency, err := strconv.Atoi(configs[5])
if err == nil && checkFrequency > 0 {
rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second
}
}
if len(configs) > 7 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
}
}
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
SentinelAddrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
MasterName: rp.masterName,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_sentinel session by sid
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != redis.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis_sentinel session exist by sid
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for redis_sentinel session
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("redis_sentinel", redispder)
}

View File

@ -0,0 +1,90 @@
package redis_sentinel
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/astaxie/beego/server/web/session"
)
func TestRedisSentinel(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
ProviderConfig: "127.0.0.1:6379,100,,0,master",
}
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
if e != nil {
t.Log(e)
return
}
//todo test if e==nil
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("session start failed:", err)
}
defer sess.SessionRelease(nil, w)
// SET AND GET
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set username failed:", err)
}
username := sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
// DELETE
err = sess.Delete(nil, "username")
if err != nil {
t.Fatal("delete username failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("delete username failed")
}
// FLUSH
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set failed:", err)
}
err = sess.Set(nil, "password", "1qaz2wsx")
if err != nil {
t.Fatal("set failed:", err)
}
username = sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
password := sess.Get(nil, "password")
if password != "1qaz2wsx" {
t.Fatal("get password failed")
}
err = sess.Flush(nil)
if err != nil {
t.Fatal("flush failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("flush failed")
}
password = sess.Get(nil, "password")
if password != nil {
t.Fatal("flush failed")
}
sess.SessionRelease(nil, w)
}

View File

@ -0,0 +1,181 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/json"
"net/http"
"net/url"
"sync"
)
var cookiepder = &CookieProvider{}
// CookieSessionStore Cookie SessionStore
type CookieSessionStore struct {
sid string
values map[interface{}]interface{} // session data
lock sync.RWMutex
}
// Set value to cookie session.
// the value are encoded as gob with hash block string.
func (st *CookieSessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// Get value from cookie session
func (st *CookieSessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
}
return nil
}
// Delete value in cookie session
func (st *CookieSessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// Flush Clean all values in cookie session
func (st *CookieSessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// SessionID Return id of this cookie session
func (st *CookieSessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
st.lock.Lock()
encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values)
st.lock.Unlock()
if err == nil {
cookie := &http.Cookie{Name: cookiepder.config.CookieName,
Value: url.QueryEscape(encodedCookie),
Path: "/",
HttpOnly: true,
Secure: cookiepder.config.Secure,
MaxAge: cookiepder.config.Maxage}
http.SetCookie(w, cookie)
}
}
type cookieConfig struct {
SecurityKey string `json:"securityKey"`
BlockKey string `json:"blockKey"`
SecurityName string `json:"securityName"`
CookieName string `json:"cookieName"`
Secure bool `json:"secure"`
Maxage int `json:"maxage"`
}
// CookieProvider Cookie session provider
type CookieProvider struct {
maxlifetime int64
config *cookieConfig
block cipher.Block
}
// SessionInit Init cookie session provider with max lifetime and config json.
// maxlifetime is ignored.
// json config:
// securityKey - hash string
// blockKey - gob encode hash string. it's saved as aes crypto.
// securityName - recognized name in encoded cookie string
// cookieName - cookie name
// maxage - cookie max life time.
func (pder *CookieProvider) SessionInit(ctx context.Context, maxlifetime int64, config string) error {
pder.config = &cookieConfig{}
err := json.Unmarshal([]byte(config), pder.config)
if err != nil {
return err
}
if pder.config.BlockKey == "" {
pder.config.BlockKey = string(generateRandomKey(16))
}
if pder.config.SecurityName == "" {
pder.config.SecurityName = string(generateRandomKey(20))
}
pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
if err != nil {
return err
}
pder.maxlifetime = maxlifetime
return nil
}
// SessionRead Get SessionStore in cooke.
// decode cooke string to map and put into SessionStore with sid.
func (pder *CookieProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
sid, pder.maxlifetime)
if maps == nil {
maps = make(map[interface{}]interface{})
}
rs := &CookieSessionStore{sid: sid, values: maps}
return rs, nil
}
// SessionExist Cookie session is always existed
func (pder *CookieProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
return true, nil
}
// SessionRegenerate Implement method, no used.
func (pder *CookieProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
return nil, nil
}
// SessionDestroy Implement method, no used.
func (pder *CookieProvider) SessionDestroy(ctx context.Context, sid string) error {
return nil
}
// SessionGC Implement method, no used.
func (pder *CookieProvider) SessionGC(context.Context) {
}
// SessionAll Implement method, return 0.
func (pder *CookieProvider) SessionAll(context.Context) int {
return 0
}
// SessionUpdate Implement method, no used.
func (pder *CookieProvider) SessionUpdate(ctx context.Context, sid string) error {
return nil
}
func init() {
Register("cookie", cookiepder)
}

View File

@ -0,0 +1,105 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
conf := new(ManagerConfig)
if err := json.Unmarshal([]byte(config), conf); err != nil {
t.Fatal("json decode error", err)
}
globalSessions, err := NewManager("cookie", conf)
if err != nil {
t.Fatal("init cookie session err", err)
}
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("set error,", err)
}
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get(nil, "username"); username != "astaxie" {
t.Fatal("get username error")
}
sess.SessionRelease(nil, w)
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}
func TestDestorySessionCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
conf := new(ManagerConfig)
if err := json.Unmarshal([]byte(config), conf); err != nil {
t.Fatal("json decode error", err)
}
globalSessions, err := NewManager("cookie", conf)
if err != nil {
t.Fatal("init cookie session err", err)
}
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
session, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("session start err,", err)
}
// request again ,will get same sesssion id .
r1, _ := http.NewRequest("GET", "/", nil)
r1.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
w = httptest.NewRecorder()
newSession, err := globalSessions.SessionStart(w, r1)
if err != nil {
t.Fatal("session start err,", err)
}
if newSession.SessionID(nil) != session.SessionID(nil) {
t.Fatal("get cookie session id is not the same again.")
}
// After destroy session , will get a new session id .
globalSessions.SessionDestroy(w, r1)
r2, _ := http.NewRequest("GET", "/", nil)
r2.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
w = httptest.NewRecorder()
newSession, err = globalSessions.SessionStart(w, r2)
if err != nil {
t.Fatal("session start error")
}
if newSession.SessionID(nil) == session.SessionID(nil) {
t.Fatal("after destroy session and reqeust again ,get cookie session id is same.")
}
}

View File

@ -0,0 +1,316 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
)
var (
filepder = &FileProvider{}
gcmaxlifetime int64
)
// FileSessionStore File session store
type FileSessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// Set value to file session
func (fs *FileSessionStore) Set(ctx context.Context, key, value interface{}) error {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.values[key] = value
return nil
}
// Get value from file session
func (fs *FileSessionStore) Get(ctx context.Context, key interface{}) interface{} {
fs.lock.RLock()
defer fs.lock.RUnlock()
if v, ok := fs.values[key]; ok {
return v
}
return nil
}
// Delete value in file session by given key
func (fs *FileSessionStore) Delete(ctx context.Context, key interface{}) error {
fs.lock.Lock()
defer fs.lock.Unlock()
delete(fs.values, key)
return nil
}
// Flush Clean all values in file session
func (fs *FileSessionStore) Flush(context.Context) error {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.values = make(map[interface{}]interface{})
return nil
}
// SessionID Get file session store id
func (fs *FileSessionStore) SessionID(context.Context) string {
return fs.sid
}
// SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
b, err := EncodeGob(fs.values)
if err != nil {
SLogger.Println(err)
return
}
_, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777)
if err != nil {
SLogger.Println(err)
return
}
} else if os.IsNotExist(err) {
f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
if err != nil {
SLogger.Println(err)
return
}
} else {
return
}
f.Truncate(0)
f.Seek(0, 0)
f.Write(b)
f.Close()
}
// FileProvider File session provider
type FileProvider struct {
lock sync.RWMutex
maxlifetime int64
savePath string
}
// SessionInit Init file session provider.
// savePath sets the session files path.
func (fp *FileProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
fp.maxlifetime = maxlifetime
fp.savePath = savePath
return nil
}
// SessionRead Read file session by sid.
// if file is not exist, create it.
// the file path is generated from sid string.
func (fp *FileProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
invalidChars := "./"
if strings.ContainsAny(sid, invalidChars) {
return nil, errors.New("the sid shouldn't have following characters: " + invalidChars)
}
if len(sid) < 2 {
return nil, errors.New("length of the sid is less than 2")
}
filepder.lock.Lock()
defer filepder.lock.Unlock()
err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755)
if err != nil {
SLogger.Println(err.Error())
}
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777)
} else if os.IsNotExist(err) {
f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
} else {
return nil, err
}
defer f.Close()
os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
var kv map[interface{}]interface{}
b, err := ioutil.ReadAll(f)
if err != nil {
return nil, err
}
if len(b) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = DecodeGob(b)
if err != nil {
return nil, err
}
}
ss := &FileSessionStore{sid: sid, values: kv}
return ss, nil
}
// SessionExist Check file session exist.
// it checks the file named from sid exist or not.
func (fp *FileProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
if len(sid) < 2 {
SLogger.Println("min length of session id is 2 but got length: ", sid)
return false, errors.New("min length of session id is 2")
}
_, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
return err == nil, nil
}
// SessionDestroy Remove all files in this save path
func (fp *FileProvider) SessionDestroy(ctx context.Context, sid string) error {
filepder.lock.Lock()
defer filepder.lock.Unlock()
os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
return nil
}
// SessionGC Recycle files in save path
func (fp *FileProvider) SessionGC(context.Context) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
gcmaxlifetime = fp.maxlifetime
filepath.Walk(fp.savePath, gcpath)
}
// SessionAll Get active file session number.
// it walks save path to count files.
func (fp *FileProvider) SessionAll(context.Context) int {
a := &activeSession{}
err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error {
return a.visit(path, f, err)
})
if err != nil {
SLogger.Printf("filepath.Walk() returned %v\n", err)
return 0
}
return a.total
}
// SessionRegenerate Generate new sid for file session.
// it delete old file and create new file named from new sid.
func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]))
oldSidFile := path.Join(oldPath, oldsid)
newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1]))
newSidFile := path.Join(newPath, sid)
// new sid file is exist
_, err := os.Stat(newSidFile)
if err == nil {
return nil, fmt.Errorf("newsid %s exist", newSidFile)
}
err = os.MkdirAll(newPath, 0755)
if err != nil {
SLogger.Println(err.Error())
}
// if old sid file exist
// 1.read and parse file content
// 2.write content to new sid file
// 3.remove old sid file, change new sid file atime and ctime
// 4.return FileSessionStore
_, err = os.Stat(oldSidFile)
if err == nil {
b, err := ioutil.ReadFile(oldSidFile)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(b) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = DecodeGob(b)
if err != nil {
return nil, err
}
}
ioutil.WriteFile(newSidFile, b, 0777)
os.Remove(oldSidFile)
os.Chtimes(newSidFile, time.Now(), time.Now())
ss := &FileSessionStore{sid: sid, values: kv}
return ss, nil
}
// if old sid file not exist, just create new sid file and return
newf, err := os.Create(newSidFile)
if err != nil {
return nil, err
}
newf.Close()
ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})}
return ss, nil
}
// remove file in save path if expired
func gcpath(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
if (info.ModTime().Unix() + gcmaxlifetime) < time.Now().Unix() {
os.Remove(path)
}
return nil
}
type activeSession struct {
total int
}
func (as *activeSession) visit(paths string, f os.FileInfo, err error) error {
if err != nil {
return err
}
if f.IsDir() {
return nil
}
as.total = as.total + 1
return nil
}
func init() {
Register("file", filepder)
}

View File

@ -0,0 +1,427 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
)
const sid = "Session_id"
const sidNew = "Session_id_new"
const sessionPath = "./_session_runtime"
var (
mutex sync.Mutex
)
func TestFileProvider_SessionInit(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
if fp.maxlifetime != 180 {
t.Error()
}
if fp.savePath != sessionPath {
t.Error()
}
}
func TestFileProvider_SessionExist(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
_, err = fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
}
func TestFileProvider_SessionExist2(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
exists, err = fp.SessionExist(context.Background(), "")
if err == nil {
t.Error()
}
if exists {
t.Error()
}
exists, err = fp.SessionExist(context.Background(), "1")
if err == nil {
t.Error()
}
if exists {
t.Error()
}
}
func TestFileProvider_SessionRead(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
s, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
_ = s.Set(nil, "sessionValue", 18975)
v := s.Get(nil, "sessionValue")
if v.(int) != 18975 {
t.Error()
}
}
func TestFileProvider_SessionRead1(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(context.Background(), "")
if err == nil {
t.Error(err)
}
_, err = fp.SessionRead(context.Background(), "1")
if err == nil {
t.Error(err)
}
}
func TestFileProvider_SessionAll(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 546
for i := 1; i <= sessionCount; i++ {
_, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
}
if fp.SessionAll(nil) != sessionCount {
t.Error()
}
}
func TestFileProvider_SessionRegenerate(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
_, err = fp.SessionRegenerate(context.Background(), sid, sidNew)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
exists, err = fp.SessionExist(context.Background(), sidNew)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
}
func TestFileProvider_SessionDestroy(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
err = fp.SessionDestroy(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
}
func TestFileProvider_SessionGC(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 1, sessionPath)
sessionCount := 412
for i := 1; i <= sessionCount; i++ {
_, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
}
time.Sleep(2 * time.Second)
fp.SessionGC(nil)
if fp.SessionAll(nil) != 0 {
t.Error()
}
}
func TestFileSessionStore_Set(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
err := s.Set(nil, i, i)
if err != nil {
t.Error(err)
}
}
}
func TestFileSessionStore_Get(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
_ = s.Set(nil, i, i)
v := s.Get(nil, i)
if v.(int) != i {
t.Error()
}
}
}
func TestFileSessionStore_Delete(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
s, _ := fp.SessionRead(context.Background(), sid)
s.Set(nil, "1", 1)
if s.Get(nil, "1") == nil {
t.Error()
}
s.Delete(nil, "1")
if s.Get(nil, "1") != nil {
t.Error()
}
}
func TestFileSessionStore_Flush(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
_ = s.Set(nil, i, i)
}
_ = s.Flush(nil)
for i := 1; i <= sessionCount; i++ {
if s.Get(nil, i) != nil {
t.Error()
}
}
}
func TestFileSessionStore_SessionID(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 85
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
if s.SessionID(nil) != fmt.Sprintf("%s_%d", sid, i) {
t.Error(err)
}
}
}
func TestFileSessionStore_SessionRelease(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
filepder.savePath = sessionPath
sessionCount := 85
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
s.Set(nil, i, i)
s.SessionRelease(nil, nil)
}
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
if s.Get(nil, i).(int) != i {
t.Error()
}
}
}

View File

@ -0,0 +1,197 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"container/list"
"context"
"net/http"
"sync"
"time"
)
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
// MemSessionStore memory session store.
// it saved sessions in a map in memory.
type MemSessionStore struct {
sid string //session id
timeAccessed time.Time //last access time
value map[interface{}]interface{} //session store
lock sync.RWMutex
}
// Set value to memory session
func (st *MemSessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.value[key] = value
return nil
}
// Get value from memory session by key
func (st *MemSessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.value[key]; ok {
return v
}
return nil
}
// Delete in memory session by key
func (st *MemSessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.value, key)
return nil
}
// Flush clear all values in memory session
func (st *MemSessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.value = make(map[interface{}]interface{})
return nil
}
// SessionID get this id of memory session store
func (st *MemSessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease Implement method, no used.
func (st *MemSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
}
// MemProvider Implement the provider interface
type MemProvider struct {
lock sync.RWMutex // locker
sessions map[string]*list.Element // map in memory
list *list.List // for gc
maxlifetime int64
savePath string
}
// SessionInit init memory session
func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
pder.maxlifetime = maxlifetime
pder.savePath = savePath
return nil
}
// SessionRead get memory session store by sid
func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
pder.lock.RLock()
if element, ok := pder.sessions[sid]; ok {
go pder.SessionUpdate(nil, sid)
pder.lock.RUnlock()
return element.Value.(*MemSessionStore), nil
}
pder.lock.RUnlock()
pder.lock.Lock()
newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
element := pder.list.PushFront(newsess)
pder.sessions[sid] = element
pder.lock.Unlock()
return newsess, nil
}
// SessionExist check session store exist in memory session by sid
func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
pder.lock.RLock()
defer pder.lock.RUnlock()
if _, ok := pder.sessions[sid]; ok {
return true, nil
}
return false, nil
}
// SessionRegenerate generate new sid for session store in memory session
func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
pder.lock.RLock()
if element, ok := pder.sessions[oldsid]; ok {
go pder.SessionUpdate(nil, oldsid)
pder.lock.RUnlock()
pder.lock.Lock()
element.Value.(*MemSessionStore).sid = sid
pder.sessions[sid] = element
delete(pder.sessions, oldsid)
pder.lock.Unlock()
return element.Value.(*MemSessionStore), nil
}
pder.lock.RUnlock()
pder.lock.Lock()
newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
element := pder.list.PushFront(newsess)
pder.sessions[sid] = element
pder.lock.Unlock()
return newsess, nil
}
// SessionDestroy delete session store in memory session by id
func (pder *MemProvider) SessionDestroy(ctx context.Context, sid string) error {
pder.lock.Lock()
defer pder.lock.Unlock()
if element, ok := pder.sessions[sid]; ok {
delete(pder.sessions, sid)
pder.list.Remove(element)
return nil
}
return nil
}
// SessionGC clean expired session stores in memory session
func (pder *MemProvider) SessionGC(context.Context) {
pder.lock.RLock()
for {
element := pder.list.Back()
if element == nil {
break
}
if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() {
pder.lock.RUnlock()
pder.lock.Lock()
pder.list.Remove(element)
delete(pder.sessions, element.Value.(*MemSessionStore).sid)
pder.lock.Unlock()
pder.lock.RLock()
} else {
break
}
}
pder.lock.RUnlock()
}
// SessionAll get count number of memory session
func (pder *MemProvider) SessionAll(context.Context) int {
return pder.list.Len()
}
// SessionUpdate expand time of session store by id in memory session
func (pder *MemProvider) SessionUpdate(ctx context.Context, sid string) error {
pder.lock.Lock()
defer pder.lock.Unlock()
if element, ok := pder.sessions[sid]; ok {
element.Value.(*MemSessionStore).timeAccessed = time.Now()
pder.list.MoveToFront(element)
return nil
}
return nil
}
func init() {
Register("memory", mempder)
}

View File

@ -0,0 +1,58 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestMem(t *testing.T) {
config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}`
conf := new(ManagerConfig)
if err := json.Unmarshal([]byte(config), conf); err != nil {
t.Fatal("json decode error", err)
}
globalSessions, _ := NewManager("memory", conf)
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("set error,", err)
}
defer sess.SessionRelease(nil, w)
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get(nil, "username"); username != "astaxie" {
t.Fatal("get username error")
}
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}

View File

@ -0,0 +1,131 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"crypto/aes"
"encoding/json"
"testing"
)
func Test_gob(t *testing.T) {
a := make(map[interface{}]interface{})
a["username"] = "astaxie"
a[12] = 234
a["user"] = User{"asta", "xie"}
b, err := EncodeGob(a)
if err != nil {
t.Error(err)
}
c, err := DecodeGob(b)
if err != nil {
t.Error(err)
}
if len(c) == 0 {
t.Error("decodeGob empty")
}
if c["username"] != "astaxie" {
t.Error("decode string error")
}
if c[12] != 234 {
t.Error("decode int error")
}
if c["user"].(User).Username != "asta" {
t.Error("decode struct error")
}
}
type User struct {
Username string
NickName string
}
func TestGenerate(t *testing.T) {
str := generateRandomKey(20)
if len(str) != 20 {
t.Fatal("generate length is not equal to 20")
}
}
func TestCookieEncodeDecode(t *testing.T) {
hashKey := "testhashKey"
blockkey := generateRandomKey(16)
block, err := aes.NewCipher(blockkey)
if err != nil {
t.Fatal("NewCipher:", err)
}
securityName := string(generateRandomKey(20))
val := make(map[interface{}]interface{})
val["name"] = "astaxie"
val["gender"] = "male"
str, err := encodeCookie(block, hashKey, securityName, val)
if err != nil {
t.Fatal("encodeCookie:", err)
}
dst, err := decodeCookie(block, hashKey, securityName, str, 3600)
if err != nil {
t.Fatal("decodeCookie", err)
}
if dst["name"] != "astaxie" {
t.Fatal("dst get map error")
}
if dst["gender"] != "male" {
t.Fatal("dst get map error")
}
}
func TestParseConfig(t *testing.T) {
s := `{"cookieName":"gosessionid","gclifetime":3600}`
cf := new(ManagerConfig)
cf.EnableSetCookie = true
err := json.Unmarshal([]byte(s), cf)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
cf2 := new(ManagerConfig)
cf2.EnableSetCookie = true
err = json.Unmarshal([]byte(cc), cf2)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf2.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf2.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
if cf2.EnableSetCookie {
t.Fatal("parseconfig get enableSetCookie error")
}
cconfig := new(cookieConfig)
err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
if err != nil {
t.Fatal("parse ProviderConfig err,", err)
}
if cconfig.CookieName != "gosessionid" {
t.Fatal("ProviderConfig get cookieName error")
}
if cconfig.SecurityKey != "beegocookiehashkey" {
t.Fatal("ProviderConfig get securityKey error")
}
}

View File

@ -0,0 +1,207 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"bytes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"io"
"strconv"
"time"
"github.com/astaxie/beego/core/utils"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
// EncodeGob encode the obj to gob
func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) {
for _, v := range obj {
gob.Register(v)
}
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
// DecodeGob decode data to map
func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}
// generateRandomKey creates a random key with the given strength.
func generateRandomKey(strength int) []byte {
k := make([]byte, strength)
if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil {
return utils.RandomCreateBytes(strength)
}
return k
}
// Encryption -----------------------------------------------------------------
// encrypt encrypts a value using the given block in counter mode.
//
// A random initialization vector (http://goo.gl/zF67k) with the length of the
// block size is prepended to the resulting ciphertext.
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
iv := generateRandomKey(block.BlockSize())
if iv == nil {
return nil, errors.New("encrypt: failed to generate random iv")
}
// Encrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
// Return iv + ciphertext.
return append(iv, value...), nil
}
// decrypt decrypts a value using the given block in counter mode.
//
// The value to be decrypted must be prepended by a initialization vector
// (http://goo.gl/zF67k) with the length of the block size.
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
size := block.BlockSize()
if len(value) > size {
// Extract iv.
iv := value[:size]
// Extract ciphertext.
value = value[size:]
// Decrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
return value, nil
}
return nil, errors.New("decrypt: the value could not be decrypted")
}
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
var err error
var b []byte
// 1. EncodeGob.
if b, err = EncodeGob(value); err != nil {
return "", err
}
// 2. Encrypt (optional).
if b, err = encrypt(block, b); err != nil {
return "", err
}
b = encode(b)
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
h := hmac.New(sha256.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
// Append mac, remove name.
b = append(b, sig...)[len(name)+1:]
// 4. Encode to base64.
b = encode(b)
// Done.
return string(b), nil
}
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
// 1. Decode from base64.
b, err := decode([]byte(value))
if err != nil {
return nil, err
}
// 2. Verify MAC. Value is "date|value|mac".
parts := bytes.SplitN(b, []byte("|"), 3)
if len(parts) != 3 {
return nil, errors.New("Decode: invalid value format")
}
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
h := hmac.New(sha256.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
return nil, errors.New("Decode: the value is not valid")
}
// 3. Verify date ranges.
var t1 int64
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
return nil, errors.New("Decode: invalid timestamp")
}
t2 := time.Now().UTC().Unix()
if t1 > t2 {
return nil, errors.New("Decode: timestamp is too new")
}
if t1 < t2-gcmaxlifetime {
return nil, errors.New("Decode: expired timestamp")
}
// 4. Decrypt (optional).
b, err = decode(parts[1])
if err != nil {
return nil, err
}
if b, err = decrypt(block, b); err != nil {
return nil, err
}
// 5. DecodeGob.
dst, err := DecodeGob(b)
if err != nil {
return nil, err
}
return dst, nil
}
// Encoding -------------------------------------------------------------------
// encode encodes a value using base64.
func encode(value []byte) []byte {
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
base64.URLEncoding.Encode(encoded, value)
return encoded
}
// decode decodes a cookie using base64.
func decode(value []byte) ([]byte, error) {
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
b, err := base64.URLEncoding.Decode(decoded, value)
if err != nil {
return nil, err
}
return decoded[:b], nil
}

View File

@ -0,0 +1,384 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package session provider
//
// Usage:
// import(
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package session
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/textproto"
"net/url"
"os"
"time"
)
// Store contains all data for one session process with specific id.
type Store interface {
Set(ctx context.Context, key, value interface{}) error //set session value
Get(ctx context.Context, key interface{}) interface{} //get session value
Delete(ctx context.Context, key interface{}) error //delete session value
SessionID(ctx context.Context) string //back current sessionID
SessionRelease(ctx context.Context, w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush(ctx context.Context) error //delete all data
}
// Provider contains global session methods and saved SessionStores.
// it can operate a SessionStore by its id.
type Provider interface {
SessionInit(ctx context.Context, gclifetime int64, config string) error
SessionRead(ctx context.Context, sid string) (Store, error)
SessionExist(ctx context.Context, sid string) (bool, error)
SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error)
SessionDestroy(ctx context.Context, sid string) error
SessionAll(ctx context.Context) int //get all active session
SessionGC(ctx context.Context)
}
var provides = make(map[string]Provider)
// SLogger a helpful variable to log information about session
var SLogger = NewSessionLog(os.Stderr)
// Register makes a session provide available by the provided name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
func Register(name string, provide Provider) {
if provide == nil {
panic("session: Register provide is nil")
}
if _, dup := provides[name]; dup {
panic("session: Register called twice for provider " + name)
}
provides[name] = provide
}
//GetProvider
func GetProvider(name string) (Provider, error) {
provider, ok := provides[name]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name)
}
return provider, nil
}
// ManagerConfig define the session config
type ManagerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"`
DisableHTTPOnly bool `json:"disableHTTPOnly"`
Secure bool `json:"secure"`
CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"`
SessionIDLength int64 `json:"sessionIDLength"`
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
SessionIDPrefix string `json:"sessionIDPrefix"`
}
// Manager contains Provider and its configuration.
type Manager struct {
provider Provider
config *ManagerConfig
}
// NewManager Create new Manager with provider name and json config string.
// provider name:
// 1. cookie
// 2. file
// 3. memory
// 4. redis
// 5. mysql
// json config:
// 1. is https default false
// 2. hashfunc default sha1
// 3. hashkey default beegosessionkey
// 4. maxage default is none
func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) {
provider, ok := provides[provideName]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
}
if cf.Maxlifetime == 0 {
cf.Maxlifetime = cf.Gclifetime
}
if cf.EnableSidInHTTPHeader {
if cf.SessionNameInHTTPHeader == "" {
panic(errors.New("SessionNameInHTTPHeader is empty"))
}
strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader)
if cf.SessionNameInHTTPHeader != strMimeHeader {
strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader
panic(errors.New(strErrMsg))
}
}
err := provider.SessionInit(nil, cf.Maxlifetime, cf.ProviderConfig)
if err != nil {
return nil, err
}
if cf.SessionIDLength == 0 {
cf.SessionIDLength = 16
}
return &Manager{
provider,
cf,
}, nil
}
// GetProvider return current manager's provider
func (manager *Manager) GetProvider() Provider {
return manager.provider
}
// getSid retrieves session identifier from HTTP Request.
// First try to retrieve id by reading from cookie, session cookie name is configurable,
// if not exist, then retrieve id from querying parameters.
//
// error is not nil when there is anything wrong.
// sid is empty when need to generate a new session id
// otherwise return an valid session id.
func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName)
if errs != nil || cookie.Value == "" {
var sid string
if manager.config.EnableSidInURLQuery {
errs := r.ParseForm()
if errs != nil {
return "", errs
}
sid = r.FormValue(manager.config.CookieName)
}
// if not found in Cookie / param, then read it from request headers
if manager.config.EnableSidInHTTPHeader && sid == "" {
sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader]
if isFound && len(sids) != 0 {
return sids[0], nil
}
}
return sid, nil
}
// HTTP Request contains cookie for sessionid info.
return url.QueryUnescape(cookie.Value)
}
// SessionStart generate or read the session id from http request.
// if session id exists, return SessionStore with this id.
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) {
sid, errs := manager.getSid(r)
if errs != nil {
return nil, errs
}
if sid != "" {
exists, err := manager.provider.SessionExist(nil, sid)
if err != nil {
return nil, err
}
if exists {
return manager.provider.SessionRead(nil, sid)
}
}
// Generate a new session
sid, errs = manager.sessionID()
if errs != nil {
return nil, errs
}
session, err = manager.provider.SessionRead(nil, sid)
if err != nil {
return nil, err
}
cookie := &http.Cookie{
Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: !manager.config.DisableHTTPOnly,
Secure: manager.isSecure(r),
Domain: manager.config.Domain,
}
if manager.config.CookieLifeTime > 0 {
cookie.MaxAge = manager.config.CookieLifeTime
cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
}
r.AddCookie(cookie)
if manager.config.EnableSidInHTTPHeader {
r.Header.Set(manager.config.SessionNameInHTTPHeader, sid)
w.Header().Set(manager.config.SessionNameInHTTPHeader, sid)
}
return
}
// SessionDestroy Destroy session by its id in http request cookie.
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
if manager.config.EnableSidInHTTPHeader {
r.Header.Del(manager.config.SessionNameInHTTPHeader)
w.Header().Del(manager.config.SessionNameInHTTPHeader)
}
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
return
}
sid, _ := url.QueryUnescape(cookie.Value)
manager.provider.SessionDestroy(nil, sid)
if manager.config.EnableSetCookie {
expiration := time.Now()
cookie = &http.Cookie{Name: manager.config.CookieName,
Path: "/",
HttpOnly: !manager.config.DisableHTTPOnly,
Expires: expiration,
MaxAge: -1,
Domain: manager.config.Domain}
http.SetCookie(w, cookie)
}
}
// GetSessionStore Get SessionStore by its id.
func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) {
sessions, err = manager.provider.SessionRead(nil, sid)
return
}
// GC Start session gc process.
// it can do gc in times after gc lifetime.
func (manager *Manager) GC() {
manager.provider.SessionGC(nil)
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
}
// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request.
func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) {
sid, err := manager.sessionID()
if err != nil {
return
}
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
//delete old cookie
session, _ = manager.provider.SessionRead(nil, sid)
cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: !manager.config.DisableHTTPOnly,
Secure: manager.isSecure(r),
Domain: manager.config.Domain,
}
} else {
oldsid, _ := url.QueryUnescape(cookie.Value)
session, _ = manager.provider.SessionRegenerate(nil, oldsid, sid)
cookie.Value = url.QueryEscape(sid)
cookie.HttpOnly = true
cookie.Path = "/"
}
if manager.config.CookieLifeTime > 0 {
cookie.MaxAge = manager.config.CookieLifeTime
cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
}
r.AddCookie(cookie)
if manager.config.EnableSidInHTTPHeader {
r.Header.Set(manager.config.SessionNameInHTTPHeader, sid)
w.Header().Set(manager.config.SessionNameInHTTPHeader, sid)
}
return
}
// GetActiveSession Get all active sessions count number.
func (manager *Manager) GetActiveSession() int {
return manager.provider.SessionAll(nil)
}
// SetSecure Set cookie with https.
func (manager *Manager) SetSecure(secure bool) {
manager.config.Secure = secure
}
func (manager *Manager) sessionID() (string, error) {
b := make([]byte, manager.config.SessionIDLength)
n, err := rand.Read(b)
if n != len(b) || err != nil {
return "", fmt.Errorf("Could not successfully read from the system CSPRNG")
}
return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil
}
// Set cookie with https.
func (manager *Manager) isSecure(req *http.Request) bool {
if !manager.config.Secure {
return false
}
if req.URL.Scheme != "" {
return req.URL.Scheme == "https"
}
if req.TLS == nil {
return false
}
return true
}
// Log implement the log.Logger
type Log struct {
*log.Logger
}
// NewSessionLog set io.Writer to create a Logger for session.
func NewSessionLog(out io.Writer) *Log {
sl := new(Log)
sl.Logger = log.New(out, "[SESSION]", 1e9)
return sl
}

View File

@ -0,0 +1,201 @@
package ssdb
import (
"context"
"errors"
"net/http"
"strconv"
"strings"
"sync"
"github.com/ssdb/gossdb/ssdb"
"github.com/astaxie/beego/server/web/session"
)
var ssdbProvider = &Provider{}
// Provider holds ssdb client and configs
type Provider struct {
client *ssdb.Client
host string
port int
maxLifetime int64
}
func (p *Provider) connectInit() error {
var err error
if p.host == "" || p.port == 0 {
return errors.New("SessionInit First")
}
p.client, err = ssdb.Connect(p.host, p.port)
return err
}
// SessionInit init the ssdb with the config
func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error {
p.maxLifetime = maxLifetime
address := strings.Split(savePath, ":")
p.host = address[0]
var err error
if p.port, err = strconv.Atoi(address[1]); err != nil {
return err
}
return p.connectInit()
}
// SessionRead return a ssdb client session Store
func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if p.client == nil {
if err := p.connectInit(); err != nil {
return nil, err
}
}
var kv map[interface{}]interface{}
value, err := p.client.Get(sid)
if err != nil {
return nil, err
}
if value == nil || len(value.(string)) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob([]byte(value.(string)))
if err != nil {
return nil, err
}
}
rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
return rs, nil
}
// SessionExist judged whether sid is exist in session
func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
if p.client == nil {
if err := p.connectInit(); err != nil {
return false, err
}
}
value, err := p.client.Get(sid)
if err != nil {
panic(err)
}
if value == nil || len(value.(string)) == 0 {
return false, nil
}
return true, nil
}
// SessionRegenerate regenerate session with new sid and delete oldsid
func (p *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
// conn.Do("setx", key, v, ttl)
if p.client == nil {
if err := p.connectInit(); err != nil {
return nil, err
}
}
value, err := p.client.Get(oldsid)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if value == nil || len(value.(string)) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob([]byte(value.(string)))
if err != nil {
return nil, err
}
_, err = p.client.Del(oldsid)
if err != nil {
return nil, err
}
}
_, e := p.client.Do("setx", sid, value, p.maxLifetime)
if e != nil {
return nil, e
}
rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
return rs, nil
}
// SessionDestroy destroy the sid
func (p *Provider) SessionDestroy(ctx context.Context, sid string) error {
if p.client == nil {
if err := p.connectInit(); err != nil {
return err
}
}
_, err := p.client.Del(sid)
return err
}
// SessionGC not implemented
func (p *Provider) SessionGC(context.Context) {
}
// SessionAll not implemented
func (p *Provider) SessionAll(context.Context) int {
return 0
}
// SessionStore holds the session information which stored in ssdb
type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxLifetime int64
client *ssdb.Client
}
// Set the key and value
func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
s.values[key] = value
return nil
}
// Get return the value by the key
func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
s.lock.Lock()
defer s.lock.Unlock()
if value, ok := s.values[key]; ok {
return value
}
return nil
}
// Delete the key in session store
func (s *SessionStore) Delete(ctx context.Context, key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.values, key)
return nil
}
// Flush delete all keys and values
func (s *SessionStore) Flush(context.Context) error {
s.lock.Lock()
defer s.lock.Unlock()
s.values = make(map[interface{}]interface{})
return nil
}
// SessionID return the sessionID
func (s *SessionStore) SessionID(context.Context) string {
return s.sid
}
// SessionRelease Store the keyvalues into ssdb
func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(s.values)
if err != nil {
return
}
s.client.Do("setx", s.sid, string(b), s.maxLifetime)
}
func init() {
session.Register("ssdb", ssdbProvider)
}

235
server/web/staticfile.go Normal file
View File

@ -0,0 +1,235 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"bytes"
"errors"
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/core/logs"
lru "github.com/hashicorp/golang-lru"
"github.com/astaxie/beego/server/web/context"
)
var errNotStaticRequest = errors.New("request not a static file request")
func serverStaticRouter(ctx *context.Context) {
if ctx.Input.Method() != "GET" && ctx.Input.Method() != "HEAD" {
return
}
forbidden, filePath, fileInfo, err := lookupFile(ctx)
if err == errNotStaticRequest {
return
}
if forbidden {
exception("403", ctx)
return
}
if filePath == "" || fileInfo == nil {
if BConfig.RunMode == DEV {
logs.Warn("Can't find/open the file:", filePath, err)
}
http.NotFound(ctx.ResponseWriter, ctx.Request)
return
}
if fileInfo.IsDir() {
requestURL := ctx.Input.URL()
if requestURL[len(requestURL)-1] != '/' {
redirectURL := requestURL + "/"
if ctx.Request.URL.RawQuery != "" {
redirectURL = redirectURL + "?" + ctx.Request.URL.RawQuery
}
ctx.Redirect(302, redirectURL)
} else {
//serveFile will list dir
http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
}
return
} else if fileInfo.Size() > int64(BConfig.WebConfig.StaticCacheFileSize) {
//over size file serve with http module
http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
return
}
var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath)
var acceptEncoding string
if enableCompress {
acceptEncoding = context.ParseEncoding(ctx.Request)
}
b, n, sch, reader, err := openFile(filePath, fileInfo, acceptEncoding)
if err != nil {
if BConfig.RunMode == DEV {
logs.Warn("Can't compress the file:", filePath, err)
}
http.NotFound(ctx.ResponseWriter, ctx.Request)
return
}
if b {
ctx.Output.Header("Content-Encoding", n)
} else {
ctx.Output.Header("Content-Length", strconv.FormatInt(sch.size, 10))
}
http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, reader)
}
type serveContentHolder struct {
data []byte
modTime time.Time
size int64
originSize int64 //original file size:to judge file changed
encoding string
}
type serveContentReader struct {
*bytes.Reader
}
var (
staticFileLruCache *lru.Cache
lruLock sync.RWMutex
)
func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) {
if staticFileLruCache == nil {
//avoid lru cache error
if BConfig.WebConfig.StaticCacheFileNum >= 1 {
staticFileLruCache, _ = lru.New(BConfig.WebConfig.StaticCacheFileNum)
} else {
staticFileLruCache, _ = lru.New(1)
}
}
mapKey := acceptEncoding + ":" + filePath
lruLock.RLock()
var mapFile *serveContentHolder
if cacheItem, ok := staticFileLruCache.Get(mapKey); ok {
mapFile = cacheItem.(*serveContentHolder)
}
lruLock.RUnlock()
if isOk(mapFile, fi) {
reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)}
return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil
}
lruLock.Lock()
defer lruLock.Unlock()
if cacheItem, ok := staticFileLruCache.Get(mapKey); ok {
mapFile = cacheItem.(*serveContentHolder)
}
if !isOk(mapFile, fi) {
file, err := os.Open(filePath)
if err != nil {
return false, "", nil, nil, err
}
defer file.Close()
var bufferWriter bytes.Buffer
_, n, err := context.WriteFile(acceptEncoding, &bufferWriter, file)
if err != nil {
return false, "", nil, nil, err
}
mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), originSize: fi.Size(), encoding: n}
if isOk(mapFile, fi) {
staticFileLruCache.Add(mapKey, mapFile)
}
}
reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)}
return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil
}
func isOk(s *serveContentHolder, fi os.FileInfo) bool {
if s == nil {
return false
} else if s.size > int64(BConfig.WebConfig.StaticCacheFileSize) {
return false
}
return s.modTime == fi.ModTime() && s.originSize == fi.Size()
}
// isStaticCompress detect static files
func isStaticCompress(filePath string) bool {
for _, statExtension := range BConfig.WebConfig.StaticExtensionsToGzip {
if strings.HasSuffix(strings.ToLower(filePath), strings.ToLower(statExtension)) {
return true
}
}
return false
}
// searchFile search the file by url path
// if none the static file prefix matches ,return notStaticRequestErr
func searchFile(ctx *context.Context) (string, os.FileInfo, error) {
requestPath := filepath.ToSlash(filepath.Clean(ctx.Request.URL.Path))
// special processing : favicon.ico/robots.txt can be in any static dir
if requestPath == "/favicon.ico" || requestPath == "/robots.txt" {
file := path.Join(".", requestPath)
if fi, _ := os.Stat(file); fi != nil {
return file, fi, nil
}
for _, staticDir := range BConfig.WebConfig.StaticDir {
filePath := path.Join(staticDir, requestPath)
if fi, _ := os.Stat(filePath); fi != nil {
return filePath, fi, nil
}
}
return "", nil, errNotStaticRequest
}
for prefix, staticDir := range BConfig.WebConfig.StaticDir {
if !strings.Contains(requestPath, prefix) {
continue
}
if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' {
continue
}
filePath := path.Join(staticDir, requestPath[len(prefix):])
if fi, err := os.Stat(filePath); fi != nil {
return filePath, fi, err
}
}
return "", nil, errNotStaticRequest
}
// lookupFile find the file to serve
// if the file is dir ,search the index.html as default file( MUST NOT A DIR also)
// if the index.html not exist or is a dir, give a forbidden response depending on DirectoryIndex
func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) {
fp, fi, err := searchFile(ctx)
if fp == "" || fi == nil {
return false, "", nil, err
}
if !fi.IsDir() {
return false, fp, fi, err
}
if requestURL := ctx.Input.URL(); requestURL[len(requestURL)-1] == '/' {
ifp := filepath.Join(fp, "index.html")
if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() {
return false, ifp, ifi, err
}
}
return !BConfig.WebConfig.DirectoryIndex, fp, fi, err
}

View File

@ -0,0 +1,99 @@
package web
import (
"bytes"
"compress/gzip"
"compress/zlib"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"testing"
)
var currentWorkDir, _ = os.Getwd()
var licenseFile = filepath.Join(currentWorkDir, "LICENSE")
func testOpenFile(encoding string, content []byte, t *testing.T) {
fi, _ := os.Stat(licenseFile)
b, n, sch, reader, err := openFile(licenseFile, fi, encoding)
if err != nil {
t.Log(err)
t.Fail()
}
t.Log("open static file encoding "+n, b)
assetOpenFileAndContent(sch, reader, content, t)
}
func TestOpenStaticFile_1(t *testing.T) {
file, _ := os.Open(licenseFile)
content, _ := ioutil.ReadAll(file)
testOpenFile("", content, t)
}
func TestOpenStaticFileGzip_1(t *testing.T) {
file, _ := os.Open(licenseFile)
var zipBuf bytes.Buffer
fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression)
io.Copy(fileWriter, file)
fileWriter.Close()
content, _ := ioutil.ReadAll(&zipBuf)
testOpenFile("gzip", content, t)
}
func TestOpenStaticFileDeflate_1(t *testing.T) {
file, _ := os.Open(licenseFile)
var zipBuf bytes.Buffer
fileWriter, _ := zlib.NewWriterLevel(&zipBuf, zlib.BestCompression)
io.Copy(fileWriter, file)
fileWriter.Close()
content, _ := ioutil.ReadAll(&zipBuf)
testOpenFile("deflate", content, t)
}
func TestStaticCacheWork(t *testing.T) {
encodings := []string{"", "gzip", "deflate"}
fi, _ := os.Stat(licenseFile)
for _, encoding := range encodings {
_, _, first, _, err := openFile(licenseFile, fi, encoding)
if err != nil {
t.Error(err)
continue
}
_, _, second, _, err := openFile(licenseFile, fi, encoding)
if err != nil {
t.Error(err)
continue
}
address1 := fmt.Sprintf("%p", first)
address2 := fmt.Sprintf("%p", second)
if address1 != address2 {
t.Errorf("encoding '%v' can not hit cache", encoding)
}
}
}
func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) {
t.Log(sch.size, len(content))
if sch.size != int64(len(content)) {
t.Log("static content file size not same")
t.Fail()
}
bs, _ := ioutil.ReadAll(reader)
for i, v := range content {
if v != bs[i] {
t.Log("content not same")
t.Fail()
}
}
if staticFileLruCache.Len() == 0 {
t.Log("men map is empty")
t.Fail()
}
}

151
server/web/statistics.go Normal file
View File

@ -0,0 +1,151 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"fmt"
"sync"
"time"
"github.com/astaxie/beego/core/utils"
)
// Statistics struct
type Statistics struct {
RequestURL string
RequestController string
RequestNum int64
MinTime time.Duration
MaxTime time.Duration
TotalTime time.Duration
}
// URLMap contains several statistics struct to log different data
type URLMap struct {
lock sync.RWMutex
LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit
urlmap map[string]map[string]*Statistics
}
// AddStatistics add statistics task.
// it needs request method, request url, request controller and statistics time duration
func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController string, requesttime time.Duration) {
m.lock.Lock()
defer m.lock.Unlock()
if method, ok := m.urlmap[requestURL]; ok {
if s, ok := method[requestMethod]; ok {
s.RequestNum++
if s.MaxTime < requesttime {
s.MaxTime = requesttime
}
if s.MinTime > requesttime {
s.MinTime = requesttime
}
s.TotalTime += requesttime
} else {
nb := &Statistics{
RequestURL: requestURL,
RequestController: requestController,
RequestNum: 1,
MinTime: requesttime,
MaxTime: requesttime,
TotalTime: requesttime,
}
m.urlmap[requestURL][requestMethod] = nb
}
} else {
if m.LengthLimit > 0 && m.LengthLimit <= len(m.urlmap) {
return
}
methodmap := make(map[string]*Statistics)
nb := &Statistics{
RequestURL: requestURL,
RequestController: requestController,
RequestNum: 1,
MinTime: requesttime,
MaxTime: requesttime,
TotalTime: requesttime,
}
methodmap[requestMethod] = nb
m.urlmap[requestURL] = methodmap
}
}
// GetMap put url statistics result in io.Writer
func (m *URLMap) GetMap() map[string]interface{} {
m.lock.RLock()
defer m.lock.RUnlock()
var fields = []string{"requestUrl", "method", "times", "used", "max used", "min used", "avg used"}
var resultLists [][]string
content := make(map[string]interface{})
content["Fields"] = fields
for k, v := range m.urlmap {
for kk, vv := range v {
result := []string{
fmt.Sprintf("% -50s", k),
fmt.Sprintf("% -10s", kk),
fmt.Sprintf("% -16d", vv.RequestNum),
fmt.Sprintf("%d", vv.TotalTime),
fmt.Sprintf("% -16s", utils.ToShortTimeFormat(vv.TotalTime)),
fmt.Sprintf("%d", vv.MaxTime),
fmt.Sprintf("% -16s", utils.ToShortTimeFormat(vv.MaxTime)),
fmt.Sprintf("%d", vv.MinTime),
fmt.Sprintf("% -16s", utils.ToShortTimeFormat(vv.MinTime)),
fmt.Sprintf("%d", time.Duration(int64(vv.TotalTime)/vv.RequestNum)),
fmt.Sprintf("% -16s", utils.ToShortTimeFormat(time.Duration(int64(vv.TotalTime)/vv.RequestNum))),
}
resultLists = append(resultLists, result)
}
}
content["Data"] = resultLists
return content
}
// GetMapData return all mapdata
func (m *URLMap) GetMapData() []map[string]interface{} {
m.lock.RLock()
defer m.lock.RUnlock()
var resultLists []map[string]interface{}
for k, v := range m.urlmap {
for kk, vv := range v {
result := map[string]interface{}{
"request_url": k,
"method": kk,
"times": vv.RequestNum,
"total_time": utils.ToShortTimeFormat(vv.TotalTime),
"max_time": utils.ToShortTimeFormat(vv.MaxTime),
"min_time": utils.ToShortTimeFormat(vv.MinTime),
"avg_time": utils.ToShortTimeFormat(time.Duration(int64(vv.TotalTime) / vv.RequestNum)),
}
resultLists = append(resultLists, result)
}
}
return resultLists
}
// StatisticsMap hosld global statistics data map
var StatisticsMap *URLMap
func init() {
StatisticsMap = &URLMap{
urlmap: make(map[string]map[string]*Statistics),
}
}

View File

@ -0,0 +1,40 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"encoding/json"
"testing"
"time"
)
func TestStatics(t *testing.T) {
StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000))
StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000))
StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000))
StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000))
StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000))
StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000))
StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400))
t.Log(StatisticsMap.GetMap())
data := StatisticsMap.GetMapData()
b, err := json.Marshal(data)
if err != nil {
t.Errorf(err.Error())
}
t.Log(string(b))
}

View File

@ -0,0 +1,174 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Swagger™ is a project used to describe and document RESTful APIs.
//
// The Swagger specification defines a set of files required to describe such an API. These files can then be used by the Swagger-UI project to display the API and Swagger-Codegen to generate clients in various languages. Additional utilities can also take advantage of the resulting files, such as testing tools.
// Now in version 2.0, Swagger is more enabling than ever. And it's 100% open source software.
// Package swagger struct definition
package swagger
// Swagger list the resource
type Swagger struct {
SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"`
Infos Information `json:"info" yaml:"info"`
Host string `json:"host,omitempty" yaml:"host,omitempty"`
BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"`
Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"`
Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"`
Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"`
Paths map[string]*Item `json:"paths" yaml:"paths"`
Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"`
SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"`
Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"`
Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"`
ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"`
}
// Information Provides metadata about the API. The metadata can be used by the clients if needed.
type Information struct {
Title string `json:"title,omitempty" yaml:"title,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"`
Version string `json:"version,omitempty" yaml:"version,omitempty"`
TermsOfService string `json:"termsOfService,omitempty" yaml:"termsOfService,omitempty"`
Contact Contact `json:"contact,omitempty" yaml:"contact,omitempty"`
License *License `json:"license,omitempty" yaml:"license,omitempty"`
}
// Contact information for the exposed API.
type Contact struct {
Name string `json:"name,omitempty" yaml:"name,omitempty"`
URL string `json:"url,omitempty" yaml:"url,omitempty"`
EMail string `json:"email,omitempty" yaml:"email,omitempty"`
}
// License information for the exposed API.
type License struct {
Name string `json:"name,omitempty" yaml:"name,omitempty"`
URL string `json:"url,omitempty" yaml:"url,omitempty"`
}
// Item Describes the operations available on a single path.
type Item struct {
Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"`
Get *Operation `json:"get,omitempty" yaml:"get,omitempty"`
Put *Operation `json:"put,omitempty" yaml:"put,omitempty"`
Post *Operation `json:"post,omitempty" yaml:"post,omitempty"`
Delete *Operation `json:"delete,omitempty" yaml:"delete,omitempty"`
Options *Operation `json:"options,omitempty" yaml:"options,omitempty"`
Head *Operation `json:"head,omitempty" yaml:"head,omitempty"`
Patch *Operation `json:"patch,omitempty" yaml:"patch,omitempty"`
}
// Operation Describes a single API operation on a path.
type Operation struct {
Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"`
Summary string `json:"summary,omitempty" yaml:"summary,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"`
OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"`
Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"`
Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"`
Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"`
Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"`
Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"`
Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"`
Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"`
}
// Parameter Describes a single operation parameter.
type Parameter struct {
In string `json:"in,omitempty" yaml:"in,omitempty"`
Name string `json:"name,omitempty" yaml:"name,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"`
Required bool `json:"required,omitempty" yaml:"required,omitempty"`
Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"`
Type string `json:"type,omitempty" yaml:"type,omitempty"`
Format string `json:"format,omitempty" yaml:"format,omitempty"`
Items *ParameterItems `json:"items,omitempty" yaml:"items,omitempty"`
Default interface{} `json:"default,omitempty" yaml:"default,omitempty"`
}
// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body".
// http://swagger.io/specification/#itemsObject
type ParameterItems struct {
Type string `json:"type,omitempty" yaml:"type,omitempty"`
Format string `json:"format,omitempty" yaml:"format,omitempty"`
Items []*ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` //Required if type is "array". Describes the type of items in the array.
CollectionFormat string `json:"collectionFormat,omitempty" yaml:"collectionFormat,omitempty"`
Default string `json:"default,omitempty" yaml:"default,omitempty"`
}
// Schema Object allows the definition of input and output data types.
type Schema struct {
Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"`
Title string `json:"title,omitempty" yaml:"title,omitempty"`
Format string `json:"format,omitempty" yaml:"format,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"`
Required []string `json:"required,omitempty" yaml:"required,omitempty"`
Type string `json:"type,omitempty" yaml:"type,omitempty"`
Items *Schema `json:"items,omitempty" yaml:"items,omitempty"`
Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"`
Enum []interface{} `json:"enum,omitempty" yaml:"enum,omitempty"`
Example interface{} `json:"example,omitempty" yaml:"example,omitempty"`
}
// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification
type Propertie struct {
Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"`
Title string `json:"title,omitempty" yaml:"title,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"`
Default interface{} `json:"default,omitempty" yaml:"default,omitempty"`
Type string `json:"type,omitempty" yaml:"type,omitempty"`
Example interface{} `json:"example,omitempty" yaml:"example,omitempty"`
Required []string `json:"required,omitempty" yaml:"required,omitempty"`
Format string `json:"format,omitempty" yaml:"format,omitempty"`
ReadOnly bool `json:"readOnly,omitempty" yaml:"readOnly,omitempty"`
Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"`
Items *Propertie `json:"items,omitempty" yaml:"items,omitempty"`
AdditionalProperties *Propertie `json:"additionalProperties,omitempty" yaml:"additionalProperties,omitempty"`
}
// Response as they are returned from executing this operation.
type Response struct {
Description string `json:"description" yaml:"description"`
Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"`
Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"`
}
// Security Allows the definition of a security scheme that can be used by the operations
type Security struct {
Type string `json:"type,omitempty" yaml:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2".
Description string `json:"description,omitempty" yaml:"description,omitempty"`
Name string `json:"name,omitempty" yaml:"name,omitempty"`
In string `json:"in,omitempty" yaml:"in,omitempty"` // Valid values are "query" or "header".
Flow string `json:"flow,omitempty" yaml:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode".
AuthorizationURL string `json:"authorizationUrl,omitempty" yaml:"authorizationUrl,omitempty"`
TokenURL string `json:"tokenUrl,omitempty" yaml:"tokenUrl,omitempty"`
Scopes map[string]string `json:"scopes,omitempty" yaml:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme.
}
// Tag Allows adding meta data to a single tag that is used by the Operation Object
type Tag struct {
Name string `json:"name,omitempty" yaml:"name,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"`
ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"`
}
// ExternalDocs include Additional external documentation
type ExternalDocs struct {
Description string `json:"description,omitempty" yaml:"description,omitempty"`
URL string `json:"url,omitempty" yaml:"url,omitempty"`
}

406
server/web/template.go Normal file
View File

@ -0,0 +1,406 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"errors"
"fmt"
"html/template"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"github.com/astaxie/beego/core/logs"
"github.com/astaxie/beego/core/utils"
)
var (
beegoTplFuncMap = make(template.FuncMap)
beeViewPathTemplateLocked = false
// beeViewPathTemplates caching map and supported template file extensions per view
beeViewPathTemplates = make(map[string]map[string]*template.Template)
templatesLock sync.RWMutex
// beeTemplateExt stores the template extension which will build
beeTemplateExt = []string{"tpl", "html", "gohtml"}
// beeTemplatePreprocessors stores associations of extension -> preprocessor handler
beeTemplateEngines = map[string]templatePreProcessor{}
beeTemplateFS = defaultFSFunc
)
// ExecuteTemplate applies the template with name to the specified data object,
// writing the output to wr.
// A template will be executed safely in parallel.
func ExecuteTemplate(wr io.Writer, name string, data interface{}) error {
return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data)
}
// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object,
// writing the output to wr.
// A template will be executed safely in parallel.
func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error {
if BConfig.RunMode == DEV {
templatesLock.RLock()
defer templatesLock.RUnlock()
}
if beeTemplates, ok := beeViewPathTemplates[viewPath]; ok {
if t, ok := beeTemplates[name]; ok {
var err error
if t.Lookup(name) != nil {
err = t.ExecuteTemplate(wr, name, data)
} else {
err = t.Execute(wr, data)
}
if err != nil {
logs.Trace("template Execute err:", err)
}
return err
}
panic("can't find templatefile in the path:" + viewPath + "/" + name)
}
panic("Unknown view path:" + viewPath)
}
func init() {
beegoTplFuncMap["dateformat"] = DateFormat
beegoTplFuncMap["date"] = Date
beegoTplFuncMap["compare"] = Compare
beegoTplFuncMap["compare_not"] = CompareNot
beegoTplFuncMap["not_nil"] = NotNil
beegoTplFuncMap["not_null"] = NotNil
beegoTplFuncMap["substr"] = Substr
beegoTplFuncMap["html2str"] = HTML2str
beegoTplFuncMap["str2html"] = Str2html
beegoTplFuncMap["htmlquote"] = Htmlquote
beegoTplFuncMap["htmlunquote"] = Htmlunquote
beegoTplFuncMap["renderform"] = RenderForm
beegoTplFuncMap["assets_js"] = AssetsJs
beegoTplFuncMap["assets_css"] = AssetsCSS
beegoTplFuncMap["config"] = GetConfig
beegoTplFuncMap["map_get"] = MapGet
// Comparisons
beegoTplFuncMap["eq"] = eq // ==
beegoTplFuncMap["ge"] = ge // >=
beegoTplFuncMap["gt"] = gt // >
beegoTplFuncMap["le"] = le // <=
beegoTplFuncMap["lt"] = lt // <
beegoTplFuncMap["ne"] = ne // !=
beegoTplFuncMap["urlfor"] = URLFor // build a URL to match a Controller and it's method
}
// AddFuncMap let user to register a func in the template.
func AddFuncMap(key string, fn interface{}) error {
beegoTplFuncMap[key] = fn
return nil
}
type templatePreProcessor func(root, path string, funcs template.FuncMap) (*template.Template, error)
type templateFile struct {
root string
files map[string][]string
}
// visit will make the paths into two part,the first is subDir (without tf.root),the second is full path(without tf.root).
// if tf.root="views" and
// paths is "views/errors/404.html",the subDir will be "errors",the file will be "errors/404.html"
// paths is "views/admin/errors/404.html",the subDir will be "admin/errors",the file will be "admin/errors/404.html"
func (tf *templateFile) visit(paths string, f os.FileInfo, err error) error {
if f == nil {
return err
}
if f.IsDir() || (f.Mode()&os.ModeSymlink) > 0 {
return nil
}
if !HasTemplateExt(paths) {
return nil
}
replace := strings.NewReplacer("\\", "/")
file := strings.TrimLeft(replace.Replace(paths[len(tf.root):]), "/")
subDir := filepath.Dir(file)
tf.files[subDir] = append(tf.files[subDir], file)
return nil
}
// HasTemplateExt return this path contains supported template extension of beego or not.
func HasTemplateExt(paths string) bool {
for _, v := range beeTemplateExt {
if strings.HasSuffix(paths, "."+v) {
return true
}
}
return false
}
// AddTemplateExt add new extension for template.
func AddTemplateExt(ext string) {
for _, v := range beeTemplateExt {
if v == ext {
return
}
}
beeTemplateExt = append(beeTemplateExt, ext)
}
// AddViewPath adds a new path to the supported view paths.
//Can later be used by setting a controller ViewPath to this folder
//will panic if called after beego.Run()
func AddViewPath(viewPath string) error {
if beeViewPathTemplateLocked {
if _, exist := beeViewPathTemplates[viewPath]; exist {
return nil //Ignore if viewpath already exists
}
panic("Can not add new view paths after beego.Run()")
}
beeViewPathTemplates[viewPath] = make(map[string]*template.Template)
return BuildTemplate(viewPath)
}
func lockViewPaths() {
beeViewPathTemplateLocked = true
}
// BuildTemplate will build all template files in a directory.
// it makes beego can render any template file in view directory.
func BuildTemplate(dir string, files ...string) error {
var err error
fs := beeTemplateFS()
f, err := fs.Open(dir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return errors.New("dir open err")
}
defer f.Close()
beeTemplates, ok := beeViewPathTemplates[dir]
if !ok {
panic("Unknown view path: " + dir)
}
self := &templateFile{
root: dir,
files: make(map[string][]string),
}
err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error {
return self.visit(path, f, err)
})
if err != nil {
fmt.Printf("Walk() returned %v\n", err)
return err
}
buildAllFiles := len(files) == 0
for _, v := range self.files {
for _, file := range v {
if buildAllFiles || utils.InSlice(file, files) {
templatesLock.Lock()
ext := filepath.Ext(file)
var t *template.Template
if len(ext) == 0 {
t, err = getTemplate(self.root, fs, file, v...)
} else if fn, ok := beeTemplateEngines[ext[1:]]; ok {
t, err = fn(self.root, file, beegoTplFuncMap)
} else {
t, err = getTemplate(self.root, fs, file, v...)
}
if err != nil {
logs.Error("parse template err:", file, err)
templatesLock.Unlock()
return err
}
beeTemplates[file] = t
templatesLock.Unlock()
}
}
}
return nil
}
func getTplDeep(root string, fs http.FileSystem, file string, parent string, t *template.Template) (*template.Template, [][]string, error) {
var fileAbsPath string
var rParent string
var err error
if strings.HasPrefix(file, "../") {
rParent = filepath.Join(filepath.Dir(parent), file)
fileAbsPath = filepath.Join(root, filepath.Dir(parent), file)
} else {
rParent = file
fileAbsPath = filepath.Join(root, file)
}
f, err := fs.Open(fileAbsPath)
if err != nil {
panic("can't find template file:" + file)
}
defer f.Close()
data, err := ioutil.ReadAll(f)
if err != nil {
return nil, [][]string{}, err
}
t, err = t.New(file).Parse(string(data))
if err != nil {
return nil, [][]string{}, err
}
reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*template[ ]+\"([^\"]+)\"")
allSub := reg.FindAllStringSubmatch(string(data), -1)
for _, m := range allSub {
if len(m) == 2 {
tl := t.Lookup(m[1])
if tl != nil {
continue
}
if !HasTemplateExt(m[1]) {
continue
}
_, _, err = getTplDeep(root, fs, m[1], rParent, t)
if err != nil {
return nil, [][]string{}, err
}
}
}
return t, allSub, nil
}
func getTemplate(root string, fs http.FileSystem, file string, others ...string) (t *template.Template, err error) {
t = template.New(file).Delims(BConfig.WebConfig.TemplateLeft, BConfig.WebConfig.TemplateRight).Funcs(beegoTplFuncMap)
var subMods [][]string
t, subMods, err = getTplDeep(root, fs, file, "", t)
if err != nil {
return nil, err
}
t, err = _getTemplate(t, root, fs, subMods, others...)
if err != nil {
return nil, err
}
return
}
func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMods [][]string, others ...string) (t *template.Template, err error) {
t = t0
for _, m := range subMods {
if len(m) == 2 {
tpl := t.Lookup(m[1])
if tpl != nil {
continue
}
//first check filename
for _, otherFile := range others {
if otherFile == m[1] {
var subMods1 [][]string
t, subMods1, err = getTplDeep(root, fs, otherFile, "", t)
if err != nil {
logs.Trace("template parse file err:", err)
} else if len(subMods1) > 0 {
t, err = _getTemplate(t, root, fs, subMods1, others...)
}
break
}
}
//second check define
for _, otherFile := range others {
var data []byte
fileAbsPath := filepath.Join(root, otherFile)
f, err := fs.Open(fileAbsPath)
if err != nil {
f.Close()
logs.Trace("template file parse error, not success open file:", err)
continue
}
data, err = ioutil.ReadAll(f)
f.Close()
if err != nil {
logs.Trace("template file parse error, not success read file:", err)
continue
}
reg := regexp.MustCompile(BConfig.WebConfig.TemplateLeft + "[ ]*define[ ]+\"([^\"]+)\"")
allSub := reg.FindAllStringSubmatch(string(data), -1)
for _, sub := range allSub {
if len(sub) == 2 && sub[1] == m[1] {
var subMods1 [][]string
t, subMods1, err = getTplDeep(root, fs, otherFile, "", t)
if err != nil {
logs.Trace("template parse file err:", err)
} else if len(subMods1) > 0 {
t, err = _getTemplate(t, root, fs, subMods1, others...)
if err != nil {
logs.Trace("template parse file err:", err)
}
}
break
}
}
}
}
}
return
}
type templateFSFunc func() http.FileSystem
func defaultFSFunc() http.FileSystem {
return FileSystem{}
}
// SetTemplateFSFunc set default filesystem function
func SetTemplateFSFunc(fnt templateFSFunc) {
beeTemplateFS = fnt
}
// SetViewsPath sets view directory path in beego application.
func SetViewsPath(path string) *HttpServer {
BConfig.WebConfig.ViewsPath = path
return BeeApp
}
// SetStaticPath sets static directory path and proper url pattern in beego application.
// if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public".
func SetStaticPath(url string, path string) *HttpServer {
if !strings.HasPrefix(url, "/") {
url = "/" + url
}
if url != "/" {
url = strings.TrimRight(url, "/")
}
BConfig.WebConfig.StaticDir[url] = path
return BeeApp
}
// DelStaticPath removes the static folder setting in this url pattern in beego application.
func DelStaticPath(url string) *HttpServer {
if !strings.HasPrefix(url, "/") {
url = "/" + url
}
if url != "/" {
url = strings.TrimRight(url, "/")
}
delete(BConfig.WebConfig.StaticDir, url)
return BeeApp
}
// AddTemplateEngine add a new templatePreProcessor which support extension
func AddTemplateEngine(extension string, fn templatePreProcessor) *HttpServer {
AddTemplateExt(extension)
beeTemplateEngines[extension] = fn
return BeeApp
}

333
server/web/template_test.go Normal file
View File

@ -0,0 +1,333 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"bytes"
"net/http"
"os"
"path/filepath"
"testing"
assetfs "github.com/elazarl/go-bindata-assetfs"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/test"
)
var header = `{{define "header"}}
<h1>Hello, astaxie!</h1>
{{end}}`
var index = `<!DOCTYPE html>
<html>
<head>
<title>beego welcome template</title>
</head>
<body>
{{template "block"}}
{{template "header"}}
{{template "blocks/block.tpl"}}
</body>
</html>
`
var block = `{{define "block"}}
<h1>Hello, blocks!</h1>
{{end}}`
func TestTemplate(t *testing.T) {
wkdir, err := os.Getwd()
assert.Nil(t, err)
dir := filepath.Join(wkdir, "_beeTmp", "TestTemplate")
files := []string{
"header.tpl",
"index.tpl",
"blocks/block.tpl",
}
if err := os.MkdirAll(dir, 0777); err != nil {
t.Fatal(err)
}
for k, name := range files {
dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
assert.Nil(t, dirErr)
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
t.Fatal(err)
} else {
if k == 0 {
f.WriteString(header)
} else if k == 1 {
f.WriteString(index)
} else if k == 2 {
f.WriteString(block)
}
f.Close()
}
}
if err := AddViewPath(dir); err != nil {
t.Fatal(err)
}
beeTemplates := beeViewPathTemplates[dir]
if len(beeTemplates) != 3 {
t.Fatalf("should be 3 but got %v", len(beeTemplates))
}
if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", nil); err != nil {
t.Fatal(err)
}
for _, name := range files {
os.RemoveAll(filepath.Join(dir, name))
}
os.RemoveAll(dir)
}
var menu = `<div class="menu">
<ul>
<li>menu1</li>
<li>menu2</li>
<li>menu3</li>
</ul>
</div>
`
var user = `<!DOCTYPE html>
<html>
<head>
<title>beego welcome template</title>
</head>
<body>
{{template "../public/menu.tpl"}}
</body>
</html>
`
func TestRelativeTemplate(t *testing.T) {
wkdir, err := os.Getwd()
assert.Nil(t, err)
dir := filepath.Join(wkdir, "_beeTmp")
//Just add dir to known viewPaths
if err := AddViewPath(dir); err != nil {
t.Fatal(err)
}
files := []string{
"easyui/public/menu.tpl",
"easyui/rbac/user.tpl",
}
if err := os.MkdirAll(dir, 0777); err != nil {
t.Fatal(err)
}
for k, name := range files {
os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
t.Fatal(err)
} else {
if k == 0 {
f.WriteString(menu)
} else if k == 1 {
f.WriteString(user)
}
f.Close()
}
}
if err := BuildTemplate(dir, files[1]); err != nil {
t.Fatal(err)
}
beeTemplates := beeViewPathTemplates[dir]
if err := beeTemplates["easyui/rbac/user.tpl"].ExecuteTemplate(os.Stdout, "easyui/rbac/user.tpl", nil); err != nil {
t.Fatal(err)
}
for _, name := range files {
os.RemoveAll(filepath.Join(dir, name))
}
os.RemoveAll(dir)
}
var add = `{{ template "layout_blog.tpl" . }}
{{ define "css" }}
<link rel="stylesheet" href="/static/css/current.css">
{{ end}}
{{ define "content" }}
<h2>{{ .Title }}</h2>
<p> This is SomeVar: {{ .SomeVar }}</p>
{{ end }}
{{ define "js" }}
<script src="/static/js/current.js"></script>
{{ end}}`
var layoutBlog = `<!DOCTYPE html>
<html>
<head>
<title>Lin Li</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="Content-Type" content="text/html; charset=utf-8">
<link rel="stylesheet" href="http://netdna.bootstrapcdn.com/bootstrap/3.0.3/css/bootstrap.min.css">
<link rel="stylesheet" href="http://netdna.bootstrapcdn.com/bootstrap/3.0.3/css/bootstrap-theme.min.css">
{{ block "css" . }}{{ end }}
</head>
<body>
<div class="container">
{{ block "content" . }}{{ end }}
</div>
<script type="text/javascript" src="http://code.jquery.com/jquery-2.0.3.min.js"></script>
<script src="http://netdna.bootstrapcdn.com/bootstrap/3.0.3/js/bootstrap.min.js"></script>
{{ block "js" . }}{{ end }}
</body>
</html>`
var output = `<!DOCTYPE html>
<html>
<head>
<title>Lin Li</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="Content-Type" content="text/html; charset=utf-8">
<link rel="stylesheet" href="http://netdna.bootstrapcdn.com/bootstrap/3.0.3/css/bootstrap.min.css">
<link rel="stylesheet" href="http://netdna.bootstrapcdn.com/bootstrap/3.0.3/css/bootstrap-theme.min.css">
<link rel="stylesheet" href="/static/css/current.css">
</head>
<body>
<div class="container">
<h2>Hello</h2>
<p> This is SomeVar: val</p>
</div>
<script type="text/javascript" src="http://code.jquery.com/jquery-2.0.3.min.js"></script>
<script src="http://netdna.bootstrapcdn.com/bootstrap/3.0.3/js/bootstrap.min.js"></script>
<script src="/static/js/current.js"></script>
</body>
</html>
`
func TestTemplateLayout(t *testing.T) {
wkdir, err := os.Getwd()
assert.Nil(t, err)
dir := filepath.Join(wkdir, "_beeTmp", "TestTemplateLayout")
files := []string{
"add.tpl",
"layout_blog.tpl",
}
if err := os.MkdirAll(dir, 0777); err != nil {
t.Fatal(err)
}
for k, name := range files {
dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
assert.Nil(t, dirErr)
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
t.Fatal(err)
} else {
if k == 0 {
_, writeErr := f.WriteString(add)
assert.Nil(t, writeErr)
} else if k == 1 {
_, writeErr := f.WriteString(layoutBlog)
assert.Nil(t, writeErr)
}
clErr := f.Close()
assert.Nil(t, clErr)
}
}
if err := AddViewPath(dir); err != nil {
t.Fatal(err)
}
beeTemplates := beeViewPathTemplates[dir]
if len(beeTemplates) != 2 {
t.Fatalf("should be 2 but got %v", len(beeTemplates))
}
out := bytes.NewBufferString("")
if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil {
t.Fatal(err)
}
if out.String() != output {
t.Log(out.String())
t.Fatal("Compare failed")
}
for _, name := range files {
os.RemoveAll(filepath.Join(dir, name))
}
os.RemoveAll(dir)
}
type TestingFileSystem struct {
assetfs *assetfs.AssetFS
}
func (d TestingFileSystem) Open(name string) (http.File, error) {
return d.assetfs.Open(name)
}
var outputBinData = `<!DOCTYPE html>
<html>
<head>
<title>beego welcome template</title>
</head>
<body>
<h1>Hello, blocks!</h1>
<h1>Hello, astaxie!</h1>
<h2>Hello</h2>
<p> This is SomeVar: val</p>
</body>
</html>
`
func TestFsBinData(t *testing.T) {
SetTemplateFSFunc(func() http.FileSystem {
return TestingFileSystem{&assetfs.AssetFS{Asset: test.Asset, AssetDir: test.AssetDir, AssetInfo: test.AssetInfo}}
})
dir := "views"
if err := AddViewPath("views"); err != nil {
t.Fatal(err)
}
beeTemplates := beeViewPathTemplates[dir]
if len(beeTemplates) != 3 {
t.Fatalf("should be 3 but got %v", len(beeTemplates))
}
if err := beeTemplates["index.tpl"].ExecuteTemplate(os.Stdout, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil {
t.Fatal(err)
}
out := bytes.NewBufferString("")
if err := beeTemplates["index.tpl"].ExecuteTemplate(out, "index.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil {
t.Fatal(err)
}
if out.String() != outputBinData {
t.Log(out.String())
t.Fatal("Compare failed")
}
}

781
server/web/templatefunc.go Normal file
View File

@ -0,0 +1,781 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"context"
"errors"
"fmt"
"html"
"html/template"
"net/url"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
const (
formatTime = "15:04:05"
formatDate = "2006-01-02"
formatDateTime = "2006-01-02 15:04:05"
formatDateTimeT = "2006-01-02T15:04:05"
)
// Substr returns the substr from start to length.
func Substr(s string, start, length int) string {
bt := []rune(s)
if start < 0 {
start = 0
}
if start > len(bt) {
start = start % len(bt)
}
var end int
if (start + length) > (len(bt) - 1) {
end = len(bt)
} else {
end = start + length
}
return string(bt[start:end])
}
// HTML2str returns escaping text convert from html.
func HTML2str(html string) string {
re := regexp.MustCompile(`\<[\S\s]+?\>`)
html = re.ReplaceAllStringFunc(html, strings.ToLower)
// remove STYLE
re = regexp.MustCompile(`\<style[\S\s]+?\</style\>`)
html = re.ReplaceAllString(html, "")
// remove SCRIPT
re = regexp.MustCompile(`\<script[\S\s]+?\</script\>`)
html = re.ReplaceAllString(html, "")
re = regexp.MustCompile(`\<[\S\s]+?\>`)
html = re.ReplaceAllString(html, "\n")
re = regexp.MustCompile(`\s{2,}`)
html = re.ReplaceAllString(html, "\n")
return strings.TrimSpace(html)
}
// DateFormat takes a time and a layout string and returns a string with the formatted date. Used by the template parser as "dateformat"
func DateFormat(t time.Time, layout string) (datestring string) {
datestring = t.Format(layout)
return
}
// DateFormat pattern rules.
var datePatterns = []string{
// year
"Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003
"y", "06", // A two digit representation of a year Examples: 99 or 03
// month
"m", "01", // Numeric representation of a month, with leading zeros 01 through 12
"n", "1", // Numeric representation of a month, without leading zeros 1 through 12
"M", "Jan", // A short textual representation of a month, three letters Jan through Dec
"F", "January", // A full textual representation of a month, such as January or March January through December
// day
"d", "02", // Day of the month, 2 digits with leading zeros 01 to 31
"j", "2", // Day of the month without leading zeros 1 to 31
// week
"D", "Mon", // A textual representation of a day, three letters Mon through Sun
"l", "Monday", // A full textual representation of the day of the week Sunday through Saturday
// time
"g", "3", // 12-hour format of an hour without leading zeros 1 through 12
"G", "15", // 24-hour format of an hour without leading zeros 0 through 23
"h", "03", // 12-hour format of an hour with leading zeros 01 through 12
"H", "15", // 24-hour format of an hour with leading zeros 00 through 23
"a", "pm", // Lowercase Ante meridiem and Post meridiem am or pm
"A", "PM", // Uppercase Ante meridiem and Post meridiem AM or PM
"i", "04", // Minutes with leading zeros 00 to 59
"s", "05", // Seconds, with leading zeros 00 through 59
// time zone
"T", "MST",
"P", "-07:00",
"O", "-0700",
// RFC 2822
"r", time.RFC1123Z,
}
// DateParse Parse Date use PHP time format.
func DateParse(dateString, format string) (time.Time, error) {
replacer := strings.NewReplacer(datePatterns...)
format = replacer.Replace(format)
return time.ParseInLocation(format, dateString, time.Local)
}
// Date takes a PHP like date func to Go's time format.
func Date(t time.Time, format string) string {
replacer := strings.NewReplacer(datePatterns...)
format = replacer.Replace(format)
return t.Format(format)
}
// Compare is a quick and dirty comparison function. It will convert whatever you give it to strings and see if the two values are equal.
// Whitespace is trimmed. Used by the template parser as "eq".
func Compare(a, b interface{}) (equal bool) {
equal = false
if strings.TrimSpace(fmt.Sprintf("%v", a)) == strings.TrimSpace(fmt.Sprintf("%v", b)) {
equal = true
}
return
}
// CompareNot !Compare
func CompareNot(a, b interface{}) (equal bool) {
return !Compare(a, b)
}
// NotNil the same as CompareNot
func NotNil(a interface{}) (isNil bool) {
return CompareNot(a, nil)
}
// GetConfig get the Appconfig
func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) {
switch returnType {
case "String":
value, err = AppConfig.String(context.Background(), key)
case "Bool":
value, err = AppConfig.Bool(context.Background(), key)
case "Int":
value, err = AppConfig.Int(context.Background(), key)
case "Int64":
value, err = AppConfig.Int64(context.Background(), key)
case "Float":
value, err = AppConfig.Float(context.Background(), key)
case "DIY":
value, err = AppConfig.DIY(context.Background(), key)
default:
err = errors.New("config keys must be of type String, Bool, Int, Int64, Float, or DIY")
}
if err != nil {
if reflect.TypeOf(returnType) != reflect.TypeOf(defaultVal) {
err = errors.New("defaultVal type does not match returnType")
} else {
value, err = defaultVal, nil
}
} else if reflect.TypeOf(value).Kind() == reflect.String {
if value == "" {
if reflect.TypeOf(defaultVal).Kind() != reflect.String {
err = errors.New("defaultVal type must be a String if the returnType is a String")
} else {
value = defaultVal.(string)
}
}
}
return
}
// Str2html Convert string to template.HTML type.
func Str2html(raw string) template.HTML {
return template.HTML(raw)
}
// Htmlquote returns quoted html string.
func Htmlquote(text string) string {
// HTML编码为实体符号
/*
Encodes `text` for raw use in HTML.
>>> htmlquote("<'&\\">")
'&lt;&#39;&amp;&quot;&gt;'
*/
text = html.EscapeString(text)
text = strings.NewReplacer(
``, "&ldquo;",
``, "&rdquo;",
` `, "&nbsp;",
).Replace(text)
return strings.TrimSpace(text)
}
// Htmlunquote returns unquoted html string.
func Htmlunquote(text string) string {
// 实体符号解释为HTML
/*
Decodes `text` that's HTML quoted.
>>> htmlunquote('&lt;&#39;&amp;&quot;&gt;')
'<\\'&">'
*/
text = html.UnescapeString(text)
return strings.TrimSpace(text)
}
// URLFor returns url string with another registered controller handler with params.
// usage:
//
// URLFor(".index")
// print URLFor("index")
// router /login
// print URLFor("login")
// print URLFor("login", "next","/"")
// router /profile/:username
// print UrlFor("profile", ":username","John Doe")
// result:
// /
// /login
// /login?next=/
// /user/John%20Doe
//
// more detail http://beego.me/docs/mvc/controller/urlbuilding.md
func URLFor(endpoint string, values ...interface{}) string {
return BeeApp.Handlers.URLFor(endpoint, values...)
}
// AssetsJs returns script tag with src string.
func AssetsJs(text string) template.HTML {
text = "<script src=\"" + text + "\"></script>"
return template.HTML(text)
}
// AssetsCSS returns stylesheet link tag with src string.
func AssetsCSS(text string) template.HTML {
text = "<link href=\"" + text + "\" rel=\"stylesheet\" />"
return template.HTML(text)
}
// ParseForm will parse form values to struct via tag.
// Support for anonymous struct.
func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error {
for i := 0; i < objT.NumField(); i++ {
fieldV := objV.Field(i)
if !fieldV.CanSet() {
continue
}
fieldT := objT.Field(i)
if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct {
err := parseFormToStruct(form, fieldT.Type, fieldV)
if err != nil {
return err
}
continue
}
tags := strings.Split(fieldT.Tag.Get("form"), ",")
var tag string
if len(tags) == 0 || len(tags[0]) == 0 {
tag = fieldT.Name
} else if tags[0] == "-" {
continue
} else {
tag = tags[0]
}
formValues := form[tag]
var value string
if len(formValues) == 0 {
defaultValue := fieldT.Tag.Get("default")
if defaultValue != "" {
value = defaultValue
} else {
continue
}
}
if len(formValues) == 1 {
value = formValues[0]
if value == "" {
continue
}
}
switch fieldT.Type.Kind() {
case reflect.Bool:
if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" {
fieldV.SetBool(true)
continue
}
if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" {
fieldV.SetBool(false)
continue
}
b, err := strconv.ParseBool(value)
if err != nil {
return err
}
fieldV.SetBool(b)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
x, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return err
}
fieldV.SetInt(x)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
x, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return err
}
fieldV.SetUint(x)
case reflect.Float32, reflect.Float64:
x, err := strconv.ParseFloat(value, 64)
if err != nil {
return err
}
fieldV.SetFloat(x)
case reflect.Interface:
fieldV.Set(reflect.ValueOf(value))
case reflect.String:
fieldV.SetString(value)
case reflect.Struct:
switch fieldT.Type.String() {
case "time.Time":
var (
t time.Time
err error
)
if len(value) >= 25 {
value = value[:25]
t, err = time.ParseInLocation(time.RFC3339, value, time.Local)
} else if strings.HasSuffix(strings.ToUpper(value), "Z") {
t, err = time.ParseInLocation(time.RFC3339, value, time.Local)
} else if len(value) >= 19 {
if strings.Contains(value, "T") {
value = value[:19]
t, err = time.ParseInLocation(formatDateTimeT, value, time.Local)
} else {
value = value[:19]
t, err = time.ParseInLocation(formatDateTime, value, time.Local)
}
} else if len(value) >= 10 {
if len(value) > 10 {
value = value[:10]
}
t, err = time.ParseInLocation(formatDate, value, time.Local)
} else if len(value) >= 8 {
if len(value) > 8 {
value = value[:8]
}
t, err = time.ParseInLocation(formatTime, value, time.Local)
}
if err != nil {
return err
}
fieldV.Set(reflect.ValueOf(t))
}
case reflect.Slice:
if fieldT.Type == sliceOfInts {
formVals := form[tag]
fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(int(1))), len(formVals), len(formVals)))
for i := 0; i < len(formVals); i++ {
val, err := strconv.Atoi(formVals[i])
if err != nil {
return err
}
fieldV.Index(i).SetInt(int64(val))
}
} else if fieldT.Type == sliceOfStrings {
formVals := form[tag]
fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(formVals), len(formVals)))
for i := 0; i < len(formVals); i++ {
fieldV.Index(i).SetString(formVals[i])
}
}
}
}
return nil
}
// ParseForm will parse form values to struct via tag.
func ParseForm(form url.Values, obj interface{}) error {
objT := reflect.TypeOf(obj)
objV := reflect.ValueOf(obj)
if !isStructPtr(objT) {
return fmt.Errorf("%v must be a struct pointer", obj)
}
objT = objT.Elem()
objV = objV.Elem()
return parseFormToStruct(form, objT, objV)
}
var sliceOfInts = reflect.TypeOf([]int(nil))
var sliceOfStrings = reflect.TypeOf([]string(nil))
var unKind = map[reflect.Kind]bool{
reflect.Uintptr: true,
reflect.Complex64: true,
reflect.Complex128: true,
reflect.Array: true,
reflect.Chan: true,
reflect.Func: true,
reflect.Map: true,
reflect.Ptr: true,
reflect.Slice: true,
reflect.Struct: true,
reflect.UnsafePointer: true,
}
// RenderForm will render object to form html.
// obj must be a struct pointer.
func RenderForm(obj interface{}) template.HTML {
objT := reflect.TypeOf(obj)
objV := reflect.ValueOf(obj)
if !isStructPtr(objT) {
return template.HTML("")
}
objT = objT.Elem()
objV = objV.Elem()
var raw []string
for i := 0; i < objT.NumField(); i++ {
fieldV := objV.Field(i)
if !fieldV.CanSet() || unKind[fieldV.Kind()] {
continue
}
fieldT := objT.Field(i)
label, name, fType, id, class, ignored, required := parseFormTag(fieldT)
if ignored {
continue
}
raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class, required))
}
return template.HTML(strings.Join(raw, "</br>"))
}
// renderFormField returns a string containing HTML of a single form field.
func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string {
if id != "" {
id = " id=\"" + id + "\""
}
if class != "" {
class = " class=\"" + class + "\""
}
requiredString := ""
if required {
requiredString = " required"
}
if isValidForInput(fType) {
return fmt.Sprintf(`%v<input%v%v name="%v" type="%v" value="%v"%v>`, label, id, class, name, fType, value, requiredString)
}
return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v</%v>`, label, fType, id, class, name, requiredString, value, fType)
}
// isValidForInput checks if fType is a valid value for the `type` property of an HTML input element.
func isValidForInput(fType string) bool {
validInputTypes := strings.Fields("text password checkbox radio submit reset hidden image file button search email url tel number range date month week time datetime datetime-local color")
for _, validType := range validInputTypes {
if fType == validType {
return true
}
}
return false
}
// parseFormTag takes the stuct-tag of a StructField and parses the `form` value.
// returned are the form label, name-property, type and wether the field should be ignored.
func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) {
tags := strings.Split(fieldT.Tag.Get("form"), ",")
label = fieldT.Name + ": "
name = fieldT.Name
fType = "text"
ignored = false
id = fieldT.Tag.Get("id")
class = fieldT.Tag.Get("class")
required = false
requiredField := fieldT.Tag.Get("required")
if requiredField != "-" && requiredField != "" {
required, _ = strconv.ParseBool(requiredField)
}
switch len(tags) {
case 1:
if tags[0] == "-" {
ignored = true
}
if len(tags[0]) > 0 {
name = tags[0]
}
case 2:
if len(tags[0]) > 0 {
name = tags[0]
}
if len(tags[1]) > 0 {
fType = tags[1]
}
case 3:
if len(tags[0]) > 0 {
name = tags[0]
}
if len(tags[1]) > 0 {
fType = tags[1]
}
if len(tags[2]) > 0 {
label = tags[2]
}
}
return
}
func isStructPtr(t reflect.Type) bool {
return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
}
// go1.2 added template funcs. begin
var (
errBadComparisonType = errors.New("invalid type for comparison")
errBadComparison = errors.New("incompatible types for comparison")
errNoComparison = errors.New("missing argument for comparison")
)
type kind int
const (
invalidKind kind = iota
boolKind
complexKind
intKind
floatKind
stringKind
uintKind
)
func basicKind(v reflect.Value) (kind, error) {
switch v.Kind() {
case reflect.Bool:
return boolKind, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return intKind, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return uintKind, nil
case reflect.Float32, reflect.Float64:
return floatKind, nil
case reflect.Complex64, reflect.Complex128:
return complexKind, nil
case reflect.String:
return stringKind, nil
}
return invalidKind, errBadComparisonType
}
// eq evaluates the comparison a == b || a == c || ...
func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1)
if err != nil {
return false, err
}
if len(arg2) == 0 {
return false, errNoComparison
}
for _, arg := range arg2 {
v2 := reflect.ValueOf(arg)
k2, err := basicKind(v2)
if err != nil {
return false, err
}
if k1 != k2 {
return false, errBadComparison
}
truth := false
switch k1 {
case boolKind:
truth = v1.Bool() == v2.Bool()
case complexKind:
truth = v1.Complex() == v2.Complex()
case floatKind:
truth = v1.Float() == v2.Float()
case intKind:
truth = v1.Int() == v2.Int()
case stringKind:
truth = v1.String() == v2.String()
case uintKind:
truth = v1.Uint() == v2.Uint()
default:
panic("invalid kind")
}
if truth {
return true, nil
}
}
return false, nil
}
// ne evaluates the comparison a != b.
func ne(arg1, arg2 interface{}) (bool, error) {
// != is the inverse of ==.
equal, err := eq(arg1, arg2)
return !equal, err
}
// lt evaluates the comparison a < b.
func lt(arg1, arg2 interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1)
if err != nil {
return false, err
}
v2 := reflect.ValueOf(arg2)
k2, err := basicKind(v2)
if err != nil {
return false, err
}
if k1 != k2 {
return false, errBadComparison
}
truth := false
switch k1 {
case boolKind, complexKind:
return false, errBadComparisonType
case floatKind:
truth = v1.Float() < v2.Float()
case intKind:
truth = v1.Int() < v2.Int()
case stringKind:
truth = v1.String() < v2.String()
case uintKind:
truth = v1.Uint() < v2.Uint()
default:
panic("invalid kind")
}
return truth, nil
}
// le evaluates the comparison <= b.
func le(arg1, arg2 interface{}) (bool, error) {
// <= is < or ==.
lessThan, err := lt(arg1, arg2)
if lessThan || err != nil {
return lessThan, err
}
return eq(arg1, arg2)
}
// gt evaluates the comparison a > b.
func gt(arg1, arg2 interface{}) (bool, error) {
// > is the inverse of <=.
lessOrEqual, err := le(arg1, arg2)
if err != nil {
return false, err
}
return !lessOrEqual, nil
}
// ge evaluates the comparison a >= b.
func ge(arg1, arg2 interface{}) (bool, error) {
// >= is the inverse of <.
lessThan, err := lt(arg1, arg2)
if err != nil {
return false, err
}
return !lessThan, nil
}
// MapGet getting value from map by keys
// usage:
// Data["m"] = M{
// "a": 1,
// "1": map[string]float64{
// "c": 4,
// },
// }
//
// {{ map_get m "a" }} // return 1
// {{ map_get m 1 "c" }} // return 4
func MapGet(arg1 interface{}, arg2 ...interface{}) (interface{}, error) {
arg1Type := reflect.TypeOf(arg1)
arg1Val := reflect.ValueOf(arg1)
if arg1Type.Kind() == reflect.Map && len(arg2) > 0 {
// check whether arg2[0] type equals to arg1 key type
// if they are different, make conversion
arg2Val := reflect.ValueOf(arg2[0])
arg2Type := reflect.TypeOf(arg2[0])
if arg2Type.Kind() != arg1Type.Key().Kind() {
// convert arg2Value to string
var arg2ConvertedVal interface{}
arg2String := fmt.Sprintf("%v", arg2[0])
// convert string representation to any other type
switch arg1Type.Key().Kind() {
case reflect.Bool:
arg2ConvertedVal, _ = strconv.ParseBool(arg2String)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
arg2ConvertedVal, _ = strconv.ParseInt(arg2String, 0, 64)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
arg2ConvertedVal, _ = strconv.ParseUint(arg2String, 0, 64)
case reflect.Float32, reflect.Float64:
arg2ConvertedVal, _ = strconv.ParseFloat(arg2String, 64)
case reflect.String:
arg2ConvertedVal = arg2String
default:
arg2ConvertedVal = arg2Val.Interface()
}
arg2Val = reflect.ValueOf(arg2ConvertedVal)
}
storedVal := arg1Val.MapIndex(arg2Val)
if storedVal.IsValid() {
var result interface{}
switch arg1Type.Elem().Kind() {
case reflect.Bool:
result = storedVal.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
result = storedVal.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
result = storedVal.Uint()
case reflect.Float32, reflect.Float64:
result = storedVal.Float()
case reflect.String:
result = storedVal.String()
default:
result = storedVal.Interface()
}
// if there is more keys, handle this recursively
if len(arg2) > 1 {
return MapGet(result, arg2[1:]...)
}
return result, nil
}
return nil, nil
}
return nil, nil
}

View File

@ -0,0 +1,380 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"html/template"
"net/url"
"reflect"
"testing"
"time"
)
func TestSubstr(t *testing.T) {
s := `012345`
if Substr(s, 0, 2) != "01" {
t.Error("should be equal")
}
if Substr(s, 0, 100) != "012345" {
t.Error("should be equal")
}
if Substr(s, 12, 100) != "012345" {
t.Error("should be equal")
}
}
func TestHtml2str(t *testing.T) {
h := `<HTML><style></style><script>x<x</script></HTML><123> 123\n
\n`
if HTML2str(h) != "123\\n\n\\n" {
t.Error("should be equal")
}
}
func TestDateFormat(t *testing.T) {
ts := "Mon, 01 Jul 2013 13:27:42 CST"
tt, _ := time.Parse(time.RFC1123, ts)
if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" {
t.Errorf("2013-07-01 13:27:42 does not equal %v", ss)
}
}
func TestDate(t *testing.T) {
ts := "Mon, 01 Jul 2013 13:27:42 CST"
tt, _ := time.Parse(time.RFC1123, ts)
if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" {
t.Errorf("2013-07-01 13:27:42 does not equal %v", ss)
}
if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" {
t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss)
}
if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" {
t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss)
}
if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" {
t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss)
}
}
func TestCompareRelated(t *testing.T) {
if !Compare("abc", "abc") {
t.Error("should be equal")
}
if Compare("abc", "aBc") {
t.Error("should be not equal")
}
if !Compare("1", 1) {
t.Error("should be equal")
}
if CompareNot("abc", "abc") {
t.Error("should be equal")
}
if !CompareNot("abc", "aBc") {
t.Error("should be not equal")
}
if !NotNil("a string") {
t.Error("should not be nil")
}
}
func TestHtmlquote(t *testing.T) {
h := `&lt;&#39;&nbsp;&rdquo;&ldquo;&amp;&#34;&gt;`
s := `<' ”“&">`
if Htmlquote(s) != h {
t.Error("should be equal")
}
}
func TestHtmlunquote(t *testing.T) {
h := `&lt;&#39;&nbsp;&rdquo;&ldquo;&amp;&#34;&gt;`
s := `<' ”“&">`
if Htmlunquote(h) != s {
t.Error("should be equal")
}
}
func TestParseForm(t *testing.T) {
type ExtendInfo struct {
Hobby []string `form:"hobby"`
Memo string
}
type OtherInfo struct {
Organization string `form:"organization"`
Title string `form:"title"`
ExtendInfo
}
type user struct {
ID int `form:"-"`
tag string `form:"tag"`
Name interface{} `form:"username"`
Age int `form:"age,text"`
Email string
Intro string `form:",textarea"`
StrBool bool `form:"strbool"`
Date time.Time `form:"date,2006-01-02"`
OtherInfo
}
u := user{}
form := url.Values{
"ID": []string{"1"},
"-": []string{"1"},
"tag": []string{"no"},
"username": []string{"test"},
"age": []string{"40"},
"Email": []string{"test@gmail.com"},
"Intro": []string{"I am an engineer!"},
"strbool": []string{"yes"},
"date": []string{"2014-11-12"},
"organization": []string{"beego"},
"title": []string{"CXO"},
"hobby": []string{"", "Basketball", "Football"},
"memo": []string{"nothing"},
}
if err := ParseForm(form, u); err == nil {
t.Fatal("nothing will be changed")
}
if err := ParseForm(form, &u); err != nil {
t.Fatal(err)
}
if u.ID != 0 {
t.Errorf("ID should equal 0 but got %v", u.ID)
}
if len(u.tag) != 0 {
t.Errorf("tag's length should equal 0 but got %v", len(u.tag))
}
if u.Name.(string) != "test" {
t.Errorf("Name should equal `test` but got `%v`", u.Name.(string))
}
if u.Age != 40 {
t.Errorf("Age should equal 40 but got %v", u.Age)
}
if u.Email != "test@gmail.com" {
t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email)
}
if u.Intro != "I am an engineer!" {
t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro)
}
if !u.StrBool {
t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool)
}
y, m, d := u.Date.Date()
if y != 2014 || m.String() != "November" || d != 12 {
t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String())
}
if u.Organization != "beego" {
t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization)
}
if u.Title != "CXO" {
t.Errorf("Title should equal `CXO`, but got `%v`", u.Title)
}
if u.Hobby[0] != "" {
t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0])
}
if u.Hobby[1] != "Basketball" {
t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1])
}
if u.Hobby[2] != "Football" {
t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2])
}
if len(u.Memo) != 0 {
t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo))
}
}
func TestRenderForm(t *testing.T) {
type user struct {
ID int `form:"-"`
Name interface{} `form:"username"`
Age int `form:"age,text,年龄:"`
Sex string
Email []string
Intro string `form:",textarea"`
Ignored string `form:"-"`
}
u := user{Name: "test", Intro: "Some Text"}
output := RenderForm(u)
if output != template.HTML("") {
t.Errorf("output should be empty but got %v", output)
}
output = RenderForm(&u)
result := template.HTML(
`Name: <input name="username" type="text" value="test"></br>` +
`年龄:<input name="age" type="text" value="0"></br>` +
`Sex: <input name="Sex" type="text" value=""></br>` +
`Intro: <textarea name="Intro">Some Text</textarea>`)
if output != result {
t.Errorf("output should equal `%v` but got `%v`", result, output)
}
}
func TestRenderFormField(t *testing.T) {
html := renderFormField("Label: ", "Name", "text", "Value", "", "", false)
if html != `Label: <input name="Name" type="text" value="Value">` {
t.Errorf("Wrong html output for input[type=text]: %v ", html)
}
html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", false)
if html != `Label: <textarea name="Name">Value</textarea>` {
t.Errorf("Wrong html output for textarea: %v ", html)
}
html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", true)
if html != `Label: <textarea name="Name" required>Value</textarea>` {
t.Errorf("Wrong html output for textarea: %v ", html)
}
}
func TestParseFormTag(t *testing.T) {
// create struct to contain field with different types of struct-tag `form`
type user struct {
All int `form:"name,text,年龄:"`
NoName int `form:",hidden,年龄:"`
OnlyLabel int `form:",,年龄:"`
OnlyName int `form:"name" id:"name" class:"form-name"`
Ignored int `form:"-"`
Required int `form:"name" required:"true"`
IgnoreRequired int `form:"name"`
NotRequired int `form:"name" required:"false"`
}
objT := reflect.TypeOf(&user{}).Elem()
label, name, fType, _, _, ignored, _ := parseFormTag(objT.Field(0))
if !(name == "name" && label == "年龄:" && fType == "text" && !ignored) {
t.Errorf("Form Tag with name, label and type was not correctly parsed.")
}
label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(1))
if !(name == "NoName" && label == "年龄:" && fType == "hidden" && !ignored) {
t.Errorf("Form Tag with label and type but without name was not correctly parsed.")
}
label, name, fType, _, _, ignored, _ = parseFormTag(objT.Field(2))
if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && !ignored) {
t.Errorf("Form Tag containing only label was not correctly parsed.")
}
label, name, fType, id, class, ignored, _ := parseFormTag(objT.Field(3))
if !(name == "name" && label == "OnlyName: " && fType == "text" && !ignored &&
id == "name" && class == "form-name") {
t.Errorf("Form Tag containing only name was not correctly parsed.")
}
_, _, _, _, _, ignored, _ = parseFormTag(objT.Field(4))
if !ignored {
t.Errorf("Form Tag that should be ignored was not correctly parsed.")
}
_, name, _, _, _, _, required := parseFormTag(objT.Field(5))
if !(name == "name" && required) {
t.Errorf("Form Tag containing only name and required was not correctly parsed.")
}
_, name, _, _, _, _, required = parseFormTag(objT.Field(6))
if !(name == "name" && !required) {
t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.")
}
_, name, _, _, _, _, required = parseFormTag(objT.Field(7))
if !(name == "name" && !required) {
t.Errorf("Form Tag containing only name and not required was not correctly parsed.")
}
}
func TestMapGet(t *testing.T) {
// test one level map
m1 := map[string]int64{
"a": 1,
"1": 2,
}
if res, err := MapGet(m1, "a"); err == nil {
if res.(int64) != 1 {
t.Errorf("Should return 1, but return %v", res)
}
} else {
t.Errorf("Error happens %v", err)
}
if res, err := MapGet(m1, "1"); err == nil {
if res.(int64) != 2 {
t.Errorf("Should return 2, but return %v", res)
}
} else {
t.Errorf("Error happens %v", err)
}
if res, err := MapGet(m1, 1); err == nil {
if res.(int64) != 2 {
t.Errorf("Should return 2, but return %v", res)
}
} else {
t.Errorf("Error happens %v", err)
}
// test 2 level map
m2 := M{
"1": map[string]float64{
"2": 3.5,
},
}
if res, err := MapGet(m2, 1, 2); err == nil {
if res.(float64) != 3.5 {
t.Errorf("Should return 3.5, but return %v", res)
}
} else {
t.Errorf("Error happens %v", err)
}
// test 5 level map
m5 := M{
"1": M{
"2": M{
"3": M{
"4": M{
"5": 1.2,
},
},
},
},
}
if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil {
if res.(float64) != 1.2 {
t.Errorf("Should return 1.2, but return %v", res)
}
} else {
t.Errorf("Error happens %v", err)
}
// check whether element not exists in map
if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil {
if res != nil {
t.Errorf("Should return nil, but return %v", res)
}
} else {
t.Errorf("Error happens %v", err)
}
}

586
server/web/tree.go Normal file
View File

@ -0,0 +1,586 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"path"
"regexp"
"strings"
"github.com/astaxie/beego/core/utils"
"github.com/astaxie/beego/server/web/context"
)
var (
allowSuffixExt = []string{".json", ".xml", ".html"}
)
// Tree has three elements: FixRouter/wildcard/leaves
// fixRouter stores Fixed Router
// wildcard stores params
// leaves store the endpoint information
type Tree struct {
// prefix set for static router
prefix string
// search fix route first
fixrouters []*Tree
// if set, failure to match fixrouters search then search wildcard
wildcard *Tree
// if set, failure to match wildcard search
leaves []*leafInfo
}
// NewTree return a new Tree
func NewTree() *Tree {
return &Tree{}
}
// AddTree will add tree to the exist Tree
// prefix should has no params
func (t *Tree) AddTree(prefix string, tree *Tree) {
t.addtree(splitPath(prefix), tree, nil, "")
}
func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg string) {
if len(segments) == 0 {
panic("prefix should has path")
}
seg := segments[0]
iswild, params, regexpStr := splitSegment(seg)
// if it's ? meaning can igone this, so add one more rule for it
if len(params) > 0 && params[0] == ":" {
params = params[1:]
if len(segments[1:]) > 0 {
t.addtree(segments[1:], tree, append(wildcards, params...), reg)
} else {
filterTreeWithPrefix(tree, wildcards, reg)
}
}
// Rule: /login/*/access match /login/2009/11/access
// if already has *, and when loop the access, should as a regexpStr
if !iswild && utils.InSlice(":splat", wildcards) {
iswild = true
regexpStr = seg
}
// Rule: /user/:id/*
if seg == "*" && len(wildcards) > 0 && reg == "" {
regexpStr = "(.+)"
}
if len(segments) == 1 {
if iswild {
if regexpStr != "" {
if reg == "" {
rr := ""
for _, w := range wildcards {
if w == ":splat" {
rr = rr + "(.+)/"
} else {
rr = rr + "([^/]+)/"
}
}
regexpStr = rr + regexpStr
} else {
regexpStr = "/" + regexpStr
}
} else if reg != "" {
if seg == "*.*" {
regexpStr = "([^.]+).(.+)"
} else {
for _, w := range params {
if w == "." || w == ":" {
continue
}
regexpStr = "([^/]+)/" + regexpStr
}
}
}
reg = strings.Trim(reg+"/"+regexpStr, "/")
filterTreeWithPrefix(tree, append(wildcards, params...), reg)
t.wildcard = tree
} else {
reg = strings.Trim(reg+"/"+regexpStr, "/")
filterTreeWithPrefix(tree, append(wildcards, params...), reg)
tree.prefix = seg
t.fixrouters = append(t.fixrouters, tree)
}
return
}
if iswild {
if t.wildcard == nil {
t.wildcard = NewTree()
}
if regexpStr != "" {
if reg == "" {
rr := ""
for _, w := range wildcards {
if w == ":splat" {
rr = rr + "(.+)/"
} else {
rr = rr + "([^/]+)/"
}
}
regexpStr = rr + regexpStr
} else {
regexpStr = "/" + regexpStr
}
} else if reg != "" {
if seg == "*.*" {
regexpStr = "([^.]+).(.+)"
params = params[1:]
} else {
for range params {
regexpStr = "([^/]+)/" + regexpStr
}
}
} else {
if seg == "*.*" {
params = params[1:]
}
}
reg = strings.TrimRight(strings.TrimRight(reg, "/")+"/"+regexpStr, "/")
t.wildcard.addtree(segments[1:], tree, append(wildcards, params...), reg)
} else {
subTree := NewTree()
subTree.prefix = seg
t.fixrouters = append(t.fixrouters, subTree)
subTree.addtree(segments[1:], tree, append(wildcards, params...), reg)
}
}
func filterTreeWithPrefix(t *Tree, wildcards []string, reg string) {
for _, v := range t.fixrouters {
filterTreeWithPrefix(v, wildcards, reg)
}
if t.wildcard != nil {
filterTreeWithPrefix(t.wildcard, wildcards, reg)
}
for _, l := range t.leaves {
if reg != "" {
if l.regexps != nil {
l.wildcards = append(wildcards, l.wildcards...)
l.regexps = regexp.MustCompile("^" + reg + "/" + strings.Trim(l.regexps.String(), "^$") + "$")
} else {
for _, v := range l.wildcards {
if v == ":splat" {
reg = reg + "/(.+)"
} else {
reg = reg + "/([^/]+)"
}
}
l.regexps = regexp.MustCompile("^" + reg + "$")
l.wildcards = append(wildcards, l.wildcards...)
}
} else {
l.wildcards = append(wildcards, l.wildcards...)
if l.regexps != nil {
for _, w := range wildcards {
if w == ":splat" {
reg = "(.+)/" + reg
} else {
reg = "([^/]+)/" + reg
}
}
l.regexps = regexp.MustCompile("^" + reg + strings.Trim(l.regexps.String(), "^$") + "$")
}
}
}
}
// AddRouter call addseg function
func (t *Tree) AddRouter(pattern string, runObject interface{}) {
t.addseg(splitPath(pattern), runObject, nil, "")
}
// "/"
// "admin" ->
func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) {
if len(segments) == 0 {
if reg != "" {
t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")})
} else {
t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards})
}
} else {
seg := segments[0]
iswild, params, regexpStr := splitSegment(seg)
// if it's ? meaning can igone this, so add one more rule for it
if len(params) > 0 && params[0] == ":" {
t.addseg(segments[1:], route, wildcards, reg)
params = params[1:]
}
// Rule: /login/*/access match /login/2009/11/access
// if already has *, and when loop the access, should as a regexpStr
if !iswild && utils.InSlice(":splat", wildcards) {
iswild = true
regexpStr = seg
}
// Rule: /user/:id/*
if seg == "*" && len(wildcards) > 0 && reg == "" {
regexpStr = "(.+)"
}
if iswild {
if t.wildcard == nil {
t.wildcard = NewTree()
}
if regexpStr != "" {
if reg == "" {
rr := ""
for _, w := range wildcards {
if w == ":splat" {
rr = rr + "(.+)/"
} else {
rr = rr + "([^/]+)/"
}
}
regexpStr = rr + regexpStr
} else {
regexpStr = "/" + regexpStr
}
} else if reg != "" {
if seg == "*.*" {
regexpStr = "/([^.]+).(.+)"
params = params[1:]
} else {
for range params {
regexpStr = "/([^/]+)" + regexpStr
}
}
} else {
if seg == "*.*" {
params = params[1:]
}
}
t.wildcard.addseg(segments[1:], route, append(wildcards, params...), reg+regexpStr)
} else {
var subTree *Tree
for _, sub := range t.fixrouters {
if sub.prefix == seg {
subTree = sub
break
}
}
if subTree == nil {
subTree = NewTree()
subTree.prefix = seg
t.fixrouters = append(t.fixrouters, subTree)
}
subTree.addseg(segments[1:], route, wildcards, reg)
}
}
}
// Match router to runObject & params
func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) {
if len(pattern) == 0 || pattern[0] != '/' {
return nil
}
w := make([]string, 0, 20)
return t.match(pattern[1:], pattern, w, ctx)
}
func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) {
if len(pattern) > 0 {
i := 0
for ; i < len(pattern) && pattern[i] == '/'; i++ {
}
pattern = pattern[i:]
}
// Handle leaf nodes:
if len(pattern) == 0 {
for _, l := range t.leaves {
if ok := l.match(treePattern, wildcardValues, ctx); ok {
return l.runObject
}
}
if t.wildcard != nil {
for _, l := range t.wildcard.leaves {
if ok := l.match(treePattern, wildcardValues, ctx); ok {
return l.runObject
}
}
}
return nil
}
var seg string
i, l := 0, len(pattern)
for ; i < l && pattern[i] != '/'; i++ {
}
if i == 0 {
seg = pattern
pattern = ""
} else {
seg = pattern[:i]
pattern = pattern[i:]
}
for _, subTree := range t.fixrouters {
if subTree.prefix == seg {
if len(pattern) != 0 && pattern[0] == '/' {
treePattern = pattern[1:]
} else {
treePattern = pattern
}
runObject = subTree.match(treePattern, pattern, wildcardValues, ctx)
if runObject != nil {
break
}
}
}
if runObject == nil && len(t.fixrouters) > 0 {
// Filter the .json .xml .html extension
for _, str := range allowSuffixExt {
if strings.HasSuffix(seg, str) {
for _, subTree := range t.fixrouters {
if subTree.prefix == seg[:len(seg)-len(str)] {
runObject = subTree.match(treePattern, pattern, wildcardValues, ctx)
if runObject != nil {
ctx.Input.SetParam(":ext", str[1:])
}
}
}
}
}
}
if runObject == nil && t.wildcard != nil {
runObject = t.wildcard.match(treePattern, pattern, append(wildcardValues, seg), ctx)
}
if runObject == nil && len(t.leaves) > 0 {
wildcardValues = append(wildcardValues, seg)
start, i := 0, 0
for ; i < len(pattern); i++ {
if pattern[i] == '/' {
if i != 0 && start < len(pattern) {
wildcardValues = append(wildcardValues, pattern[start:i])
}
start = i + 1
continue
}
}
if start > 0 {
wildcardValues = append(wildcardValues, pattern[start:i])
}
for _, l := range t.leaves {
if ok := l.match(treePattern, wildcardValues, ctx); ok {
return l.runObject
}
}
}
return runObject
}
type leafInfo struct {
// names of wildcards that lead to this leaf. eg, ["id" "name"] for the wildcard ":id" and ":name"
wildcards []string
// if the leaf is regexp
regexps *regexp.Regexp
runObject interface{}
}
func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) {
// fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps)
if leaf.regexps == nil {
if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path
return true
}
// match *
if len(leaf.wildcards) == 1 && leaf.wildcards[0] == ":splat" {
ctx.Input.SetParam(":splat", treePattern)
return true
}
// match *.* or :id
if len(leaf.wildcards) >= 2 && leaf.wildcards[len(leaf.wildcards)-2] == ":path" && leaf.wildcards[len(leaf.wildcards)-1] == ":ext" {
if len(leaf.wildcards) == 2 {
lastone := wildcardValues[len(wildcardValues)-1]
strs := strings.SplitN(lastone, ".", 2)
if len(strs) == 2 {
ctx.Input.SetParam(":ext", strs[1])
}
ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[:len(wildcardValues)-1]...), strs[0]))
return true
} else if len(wildcardValues) < 2 {
return false
}
var index int
for index = 0; index < len(leaf.wildcards)-2; index++ {
ctx.Input.SetParam(leaf.wildcards[index], wildcardValues[index])
}
lastone := wildcardValues[len(wildcardValues)-1]
strs := strings.SplitN(lastone, ".", 2)
if len(strs) == 2 {
ctx.Input.SetParam(":ext", strs[1])
}
if index > (len(wildcardValues) - 1) {
ctx.Input.SetParam(":path", "")
} else {
ctx.Input.SetParam(":path", path.Join(path.Join(wildcardValues[index:len(wildcardValues)-1]...), strs[0]))
}
return true
}
// match :id
if len(leaf.wildcards) != len(wildcardValues) {
return false
}
for j, v := range leaf.wildcards {
ctx.Input.SetParam(v, wildcardValues[j])
}
return true
}
if !leaf.regexps.MatchString(path.Join(wildcardValues...)) {
return false
}
matches := leaf.regexps.FindStringSubmatch(path.Join(wildcardValues...))
for i, match := range matches[1:] {
if i < len(leaf.wildcards) {
ctx.Input.SetParam(leaf.wildcards[i], match)
}
}
return true
}
// "/" -> []
// "/admin" -> ["admin"]
// "/admin/" -> ["admin"]
// "/admin/users" -> ["admin", "users"]
func splitPath(key string) []string {
key = strings.Trim(key, "/ ")
if key == "" {
return []string{}
}
return strings.Split(key, "/")
}
// "admin" -> false, nil, ""
// ":id" -> true, [:id], ""
// "?:id" -> true, [: :id], "" : meaning can empty
// ":id:int" -> true, [:id], ([0-9]+)
// ":name:string" -> true, [:name], ([\w]+)
// ":id([0-9]+)" -> true, [:id], ([0-9]+)
// ":id([0-9]+)_:name" -> true, [:id :name], ([0-9]+)_(.+)
// "cms_:id_:page.html" -> true, [:id_ :page], cms_(.+)(.+).html
// "cms_:id(.+)_:page.html" -> true, [:id :page], cms_(.+)_(.+).html
// "*" -> true, [:splat], ""
// "*.*" -> true,[. :path :ext], "" . meaning separator
func splitSegment(key string) (bool, []string, string) {
if strings.HasPrefix(key, "*") {
if key == "*.*" {
return true, []string{".", ":path", ":ext"}, ""
}
return true, []string{":splat"}, ""
}
if strings.ContainsAny(key, ":") {
var paramsNum int
var out []rune
var start bool
var startexp bool
var param []rune
var expt []rune
var skipnum int
params := []string{}
reg := regexp.MustCompile(`[a-zA-Z0-9_]+`)
for i, v := range key {
if skipnum > 0 {
skipnum--
continue
}
if start {
// :id:int and :name:string
if v == ':' {
if len(key) >= i+4 {
if key[i+1:i+4] == "int" {
out = append(out, []rune("([0-9]+)")...)
params = append(params, ":"+string(param))
start = false
startexp = false
skipnum = 3
param = make([]rune, 0)
paramsNum++
continue
}
}
if len(key) >= i+7 {
if key[i+1:i+7] == "string" {
out = append(out, []rune(`([\w]+)`)...)
params = append(params, ":"+string(param))
paramsNum++
start = false
startexp = false
skipnum = 6
param = make([]rune, 0)
continue
}
}
}
// params only support a-zA-Z0-9
if reg.MatchString(string(v)) {
param = append(param, v)
continue
}
if v != '(' {
out = append(out, []rune(`(.+)`)...)
params = append(params, ":"+string(param))
param = make([]rune, 0)
paramsNum++
start = false
startexp = false
}
}
if startexp {
if v != ')' {
expt = append(expt, v)
continue
}
}
// Escape Sequence '\'
if i > 0 && key[i-1] == '\\' {
out = append(out, v)
} else if v == ':' {
param = make([]rune, 0)
start = true
} else if v == '(' {
startexp = true
start = false
if len(param) > 0 {
params = append(params, ":"+string(param))
param = make([]rune, 0)
}
paramsNum++
expt = make([]rune, 0)
expt = append(expt, '(')
} else if v == ')' {
startexp = false
expt = append(expt, ')')
out = append(out, expt...)
param = make([]rune, 0)
} else if v == '?' {
params = append(params, ":")
} else {
out = append(out, v)
}
}
if len(param) > 0 {
if paramsNum > 0 {
out = append(out, []rune(`(.+)`)...)
}
params = append(params, ":"+string(param))
}
return true, params, string(out)
}
return false, nil, ""
}

306
server/web/tree_test.go Normal file
View File

@ -0,0 +1,306 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"strings"
"testing"
"github.com/astaxie/beego/server/web/context"
)
type testinfo struct {
url string
requesturl string
params map[string]string
}
var routers []testinfo
func init() {
routers = make([]testinfo, 0)
routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil})
routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}})
routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}})
routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}})
routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}})
routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}})
routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}})
routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}})
routers = append(routers, testinfo{"/", "/", nil})
routers = append(routers, testinfo{"/customer/login", "/customer/login", nil})
routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}})
routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}})
routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}})
routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}})
routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}})
routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}})
routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}})
routers = append(routers, testinfo{"/thumbnail/:size/uploads/*",
"/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg",
map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}})
routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}})
routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}})
routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}})
routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*",
"/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg",
map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}})
routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}})
routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}})
routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}})
routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}})
routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}})
routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}})
routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}})
routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}})
routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}})
routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}})
routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}})
routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}})
routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}})
routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}})
routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}})
routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}})
routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}})
}
func TestTreeRouters(t *testing.T) {
for _, r := range routers {
tr := NewTree()
tr.AddRouter(r.url, "astaxie")
ctx := context.NewContext()
obj := tr.Match(r.requesturl, ctx)
if obj == nil || obj.(string) != "astaxie" {
t.Fatal(r.url+" can't get obj, Expect ", r.requesturl)
}
if r.params != nil {
for k, v := range r.params {
if vv := ctx.Input.Param(k); vv != v {
t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv)
} else if vv == "" && v != "" {
t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k)
}
}
}
}
}
func TestStaticPath(t *testing.T) {
tr := NewTree()
tr.AddRouter("/topic/:id", "wildcard")
tr.AddRouter("/topic", "static")
ctx := context.NewContext()
obj := tr.Match("/topic", ctx)
if obj == nil || obj.(string) != "static" {
t.Fatal("/topic is a static route")
}
obj = tr.Match("/topic/1", ctx)
if obj == nil || obj.(string) != "wildcard" {
t.Fatal("/topic/1 is a wildcard route")
}
}
func TestAddTree(t *testing.T) {
tr := NewTree()
tr.AddRouter("/shop/:id/account", "astaxie")
tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie")
t1 := NewTree()
t1.AddTree("/v1/zl", tr)
ctx := context.NewContext()
obj := t1.Match("/v1/zl/shop/123/account", ctx)
if obj == nil || obj.(string) != "astaxie" {
t.Fatal("/v1/zl/shop/:id/account can't get obj ")
}
if ctx.Input.ParamsLen() == 0 {
t.Fatal("get param error")
}
if ctx.Input.Param(":id") != "123" {
t.Fatal("get :id param error")
}
ctx.Input.Reset(ctx)
obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx)
if obj == nil || obj.(string) != "astaxie" {
t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ")
}
if ctx.Input.ParamsLen() == 0 {
t.Fatal("get param error")
}
if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" {
t.Fatal("get :sd :id :page param error")
}
t2 := NewTree()
t2.AddTree("/v1/:shopid", tr)
ctx.Input.Reset(ctx)
obj = t2.Match("/v1/zl/shop/123/account", ctx)
if obj == nil || obj.(string) != "astaxie" {
t.Fatal("/v1/:shopid/shop/:id/account can't get obj ")
}
if ctx.Input.ParamsLen() == 0 {
t.Fatal("get param error")
}
if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" {
t.Fatal("get :id :shopid param error")
}
ctx.Input.Reset(ctx)
obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx)
if obj == nil || obj.(string) != "astaxie" {
t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ")
}
if ctx.Input.ParamsLen() == 0 {
t.Fatal("get :shopid param error")
}
if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" {
t.Fatal("get :sd :id :page :shopid param error")
}
}
func TestAddTree2(t *testing.T) {
tr := NewTree()
tr.AddRouter("/shop/:id/account", "astaxie")
tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie")
t3 := NewTree()
t3.AddTree("/:version(v1|v2)/:prefix", tr)
ctx := context.NewContext()
obj := t3.Match("/v1/zl/shop/123/account", ctx)
if obj == nil || obj.(string) != "astaxie" {
t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ")
}
if ctx.Input.ParamsLen() == 0 {
t.Fatal("get param error")
}
if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" {
t.Fatal("get :id :prefix :version param error")
}
}
func TestAddTree3(t *testing.T) {
tr := NewTree()
tr.AddRouter("/create", "astaxie")
tr.AddRouter("/shop/:sd/account", "astaxie")
t3 := NewTree()
t3.AddTree("/table/:num", tr)
ctx := context.NewContext()
obj := t3.Match("/table/123/shop/123/account", ctx)
if obj == nil || obj.(string) != "astaxie" {
t.Fatal("/table/:num/shop/:sd/account can't get obj ")
}
if ctx.Input.ParamsLen() == 0 {
t.Fatal("get param error")
}
if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" {
t.Fatal("get :num :sd param error")
}
ctx.Input.Reset(ctx)
obj = t3.Match("/table/123/create", ctx)
if obj == nil || obj.(string) != "astaxie" {
t.Fatal("/table/:num/create can't get obj ")
}
}
func TestAddTree4(t *testing.T) {
tr := NewTree()
tr.AddRouter("/create", "astaxie")
tr.AddRouter("/shop/:sd/:account", "astaxie")
t4 := NewTree()
t4.AddTree("/:info:int/:num/:id", tr)
ctx := context.NewContext()
obj := t4.Match("/12/123/456/shop/123/account", ctx)
if obj == nil || obj.(string) != "astaxie" {
t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ")
}
if ctx.Input.ParamsLen() == 0 {
t.Fatal("get param error")
}
if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" ||
ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" ||
ctx.Input.Param(":account") != "account" {
t.Fatal("get :info :num :id :sd :account param error")
}
ctx.Input.Reset(ctx)
obj = t4.Match("/12/123/456/create", ctx)
if obj == nil || obj.(string) != "astaxie" {
t.Fatal("/:info:int/:num/:id/create can't get obj ")
}
}
// Test for issue #1595
func TestAddTree5(t *testing.T) {
tr := NewTree()
tr.AddRouter("/v1/shop/:id", "shopdetail")
tr.AddRouter("/v1/shop/", "shophome")
ctx := context.NewContext()
obj := tr.Match("/v1/shop/", ctx)
if obj == nil || obj.(string) != "shophome" {
t.Fatal("url /v1/shop/ need match router /v1/shop/ ")
}
}
func TestSplitPath(t *testing.T) {
a := splitPath("")
if len(a) != 0 {
t.Fatal("/ should retrun []")
}
a = splitPath("/")
if len(a) != 0 {
t.Fatal("/ should retrun []")
}
a = splitPath("/admin")
if len(a) != 1 || a[0] != "admin" {
t.Fatal("/admin should retrun [admin]")
}
a = splitPath("/admin/")
if len(a) != 1 || a[0] != "admin" {
t.Fatal("/admin/ should retrun [admin]")
}
a = splitPath("/admin/users")
if len(a) != 2 || a[0] != "admin" || a[1] != "users" {
t.Fatal("/admin should retrun [admin users]")
}
a = splitPath("/admin/:id:int")
if len(a) != 2 || a[0] != "admin" || a[1] != ":id:int" {
t.Fatal("/admin should retrun [admin :id:int]")
}
}
func TestSplitSegment(t *testing.T) {
items := map[string]struct {
isReg bool
params []string
regStr string
}{
"admin": {false, nil, ""},
"*": {true, []string{":splat"}, ""},
"*.*": {true, []string{".", ":path", ":ext"}, ""},
":id": {true, []string{":id"}, ""},
"?:id": {true, []string{":", ":id"}, ""},
":id:int": {true, []string{":id"}, "([0-9]+)"},
":name:string": {true, []string{":name"}, `([\w]+)`},
":id([0-9]+)": {true, []string{":id"}, `([0-9]+)`},
":id([0-9]+)_:name": {true, []string{":id", ":name"}, `([0-9]+)_(.+)`},
":id(.+)_cms.html": {true, []string{":id"}, `(.+)_cms.html`},
"cms_:id(.+)_:page(.+).html": {true, []string{":id", ":page"}, `cms_(.+)_(.+).html`},
`:app(a|b|c)`: {true, []string{":app"}, `(a|b|c)`},
`:app\((a|b|c)\)`: {true, []string{":app"}, `(.+)\((a|b|c)\)`},
}
for pattern, v := range items {
b, w, r := splitSegment(pattern)
if b != v.isReg || r != v.regStr || strings.Join(w, ",") != strings.Join(v.params, ",") {
t.Fatalf("%s should return %t,%s,%q, got %t,%s,%q", pattern, v.isReg, v.params, v.regStr, b, w, r)
}
}
}

View File

@ -0,0 +1,226 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package web
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
//
// The unregroute_test.go contains tests for the unregister route
// functionality, that allows overriding route paths in children project
// that embed parent routers.
//
const contentRootOriginal = "ok-original-root"
const contentLevel1Original = "ok-original-level1"
const contentLevel2Original = "ok-original-level2"
const contentRootReplacement = "ok-replacement-root"
const contentLevel1Replacement = "ok-replacement-level1"
const contentLevel2Replacement = "ok-replacement-level2"
// TestPreUnregController will supply content for the original routes,
// before unregistration
type TestPreUnregController struct {
Controller
}
func (tc *TestPreUnregController) GetFixedRoot() {
tc.Ctx.Output.Body([]byte(contentRootOriginal))
}
func (tc *TestPreUnregController) GetFixedLevel1() {
tc.Ctx.Output.Body([]byte(contentLevel1Original))
}
func (tc *TestPreUnregController) GetFixedLevel2() {
tc.Ctx.Output.Body([]byte(contentLevel2Original))
}
// TestPostUnregController will supply content for the overriding routes,
// after the original ones are unregistered.
type TestPostUnregController struct {
Controller
}
func (tc *TestPostUnregController) GetFixedRoot() {
tc.Ctx.Output.Body([]byte(contentRootReplacement))
}
func (tc *TestPostUnregController) GetFixedLevel1() {
tc.Ctx.Output.Body([]byte(contentLevel1Replacement))
}
func (tc *TestPostUnregController) GetFixedLevel2() {
tc.Ctx.Output.Body([]byte(contentLevel2Replacement))
}
// TestUnregisterFixedRouteRoot replaces just the root fixed route path.
// In this case, for a path like "/level1/level2" or "/level1", those actions
// should remain intact, and continue to serve the original content.
func TestUnregisterFixedRouteRoot(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
// Test original root
testHelperFnContentCheck(t, handler, "Test original root",
method, "/", contentRootOriginal)
// Test original level 1
testHelperFnContentCheck(t, handler, "Test original level 1",
method, "/level1", contentLevel1Original)
// Test original level 2
testHelperFnContentCheck(t, handler, "Test original level 2",
method, "/level1/level2", contentLevel2Original)
// Remove only the root path
findAndRemoveSingleTree(handler.routers[method])
// Replace the root path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot")
// Test replacement root (expect change)
testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement)
// Test level 1 (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original)
// Test level 2 (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original)
}
// TestUnregisterFixedRouteLevel1 replaces just the "/level1" fixed route path.
// In this case, for a path like "/level1/level2" or "/", those actions
// should remain intact, and continue to serve the original content.
func TestUnregisterFixedRouteLevel1(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
// Test original root
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original root",
method, "/", contentRootOriginal)
// Test original level 1
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original level 1",
method, "/level1", contentLevel1Original)
// Test original level 2
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original level 2",
method, "/level1/level2", contentLevel2Original)
// Remove only the level1 path
subPaths := splitPath("/level1")
if handler.routers[method].prefix == strings.Trim("/level1", "/ ") {
findAndRemoveSingleTree(handler.routers[method])
} else {
findAndRemoveTree(subPaths, handler.routers[method], method)
}
// Replace the "level1" path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1")
// Test replacement root (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
// Test level 1 (expect change)
testHelperFnContentCheck(t, handler, "Test level 1 (expect change)", method, "/level1", contentLevel1Replacement)
// Test level 2 (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original)
}
// TestUnregisterFixedRouteLevel2 unregisters just the "/level1/level2" fixed
// route path. In this case, for a path like "/level1" or "/", those actions
// should remain intact, and continue to serve the original content.
func TestUnregisterFixedRouteLevel2(t *testing.T) {
var method = "GET"
handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot")
handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1")
handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2")
// Test original root
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original root",
method, "/", contentRootOriginal)
// Test original level 1
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original level 1",
method, "/level1", contentLevel1Original)
// Test original level 2
testHelperFnContentCheck(t, handler,
"TestUnregisterFixedRouteLevel1.Test original level 2",
method, "/level1/level2", contentLevel2Original)
// Remove only the level2 path
subPaths := splitPath("/level1/level2")
if handler.routers[method].prefix == strings.Trim("/level1/level2", "/ ") {
findAndRemoveSingleTree(handler.routers[method])
} else {
findAndRemoveTree(subPaths, handler.routers[method], method)
}
// Replace the "/level1/level2" path TestPreUnregController action with the action from
// TestPostUnregController
handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2")
// Test replacement root (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
// Test level 1 (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test level 1 (expect no change from the original)", method, "/level1", contentLevel1Original)
// Test level 2 (expect change)
testHelperFnContentCheck(t, handler, "Test level 2 (expect change)", method, "/level1/level2", contentLevel2Replacement)
}
func testHelperFnContentCheck(t *testing.T, handler *ControllerRegister,
testName, method, path, expectedBodyContent string) {
r, err := http.NewRequest(method, path, nil)
if err != nil {
t.Errorf("httpRecorderBodyTest NewRequest error: %v", err)
return
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
body := w.Body.String()
if body != expectedBodyContent {
t.Errorf("%s: expected [%s], got [%s];", testName, expectedBodyContent, body)
}
}