// Mirror of https://github.com/astaxie/beego.git, synced 2024-11-05 04:50:54 +00:00.
// Original file: 669 lines, 17 KiB, Go.
package beego
|
||
|
||
import (
|
||
"fmt"
|
||
beecontext "github.com/astaxie/beego/context"
|
||
"github.com/astaxie/beego/middleware"
|
||
"net/http"
|
||
"net/url"
|
||
"os"
|
||
"reflect"
|
||
"regexp"
|
||
"runtime"
|
||
"strconv"
|
||
"strings"
|
||
)
|
||
|
||
// HTTPMETHOD lists the lower-case HTTP verbs that may appear in a method
// mapping passed to ControllerRegistor.Add (besides the "*" wildcard).
var HTTPMETHOD = []string{
	"get",
	"post",
	"put",
	"delete",
	"patch",
	"options",
	"head",
}
|
||
|
||
// controllerInfo describes one route registered through ControllerRegistor.Add.
type controllerInfo struct {
	// pattern is the registered URL for fixed routes; for regexp routes it is
	// the regular-expression source built from the URL pattern.
	pattern string
	// regex is the compiled pattern; nil for fixed (parameter-free) routes.
	regex *regexp.Regexp
	// params maps a capture-group index (0-based) to its parameter name.
	params map[int]string
	// controllerType is the non-pointer type of the registered controller.
	controllerType reflect.Type
	// methods maps a lower-case HTTP verb (or "*") to a controller method name.
	methods map[string]string
	// hasMethod reports whether methods is non-empty.
	hasMethod bool
}
|
||
|
||
type ControllerRegistor struct {
|
||
routers []*controllerInfo
|
||
fixrouters []*controllerInfo
|
||
enableFilter bool
|
||
filters map[string][]*FilterRouter
|
||
enableAuto bool
|
||
autoRouter map[string]map[string]reflect.Type //key:controller key:method value:reflect.type
|
||
}
|
||
|
||
func NewControllerRegistor() *ControllerRegistor {
|
||
return &ControllerRegistor{
|
||
routers: make([]*controllerInfo, 0),
|
||
autoRouter: make(map[string]map[string]reflect.Type),
|
||
filters: make(map[string][]*FilterRouter),
|
||
}
|
||
}
|
||
|
||
//methods support like this:
|
||
//default methods is the same name as method
|
||
//Add("/user",&UserController{})
|
||
//Add("/api/list",&RestController{},"*:ListFood")
|
||
//Add("/api/create",&RestController{},"post:CreateFood")
|
||
//Add("/api/update",&RestController{},"put:UpdateFood")
|
||
//Add("/api/delete",&RestController{},"delete:DeleteFood")
|
||
//Add("/api",&RestController{},"get,post:ApiFunc")
|
||
//Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
|
||
func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
|
||
parts := strings.Split(pattern, "/")
|
||
|
||
j := 0
|
||
params := make(map[int]string)
|
||
for i, part := range parts {
|
||
if strings.HasPrefix(part, ":") {
|
||
expr := "(.+)"
|
||
//a user may choose to override the defult expression
|
||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||
if index := strings.Index(part, "("); index != -1 {
|
||
expr = part[index:]
|
||
part = part[:index]
|
||
//match /user/:id:int ([0-9]+)
|
||
//match /post/:username:string ([\w]+)
|
||
} else if lindex := strings.LastIndex(part, ":"); lindex != 0 {
|
||
switch part[lindex:] {
|
||
case ":int":
|
||
expr = "([0-9]+)"
|
||
part = part[:lindex]
|
||
case ":string":
|
||
expr = `([\w]+)`
|
||
part = part[:lindex]
|
||
}
|
||
}
|
||
params[j] = part
|
||
parts[i] = expr
|
||
j++
|
||
}
|
||
if strings.HasPrefix(part, "*") {
|
||
expr := "(.+)"
|
||
if part == "*.*" {
|
||
params[j] = ":path"
|
||
parts[i] = "([^.]+).([^.]+)"
|
||
j++
|
||
params[j] = ":ext"
|
||
j++
|
||
} else {
|
||
params[j] = ":splat"
|
||
parts[i] = expr
|
||
j++
|
||
}
|
||
}
|
||
}
|
||
reflectVal := reflect.ValueOf(c)
|
||
t := reflect.Indirect(reflectVal).Type()
|
||
methods := make(map[string]string)
|
||
if len(mappingMethods) > 0 {
|
||
semi := strings.Split(mappingMethods[0], ";")
|
||
for _, v := range semi {
|
||
colon := strings.Split(v, ":")
|
||
if len(colon) != 2 {
|
||
panic("method mapping fomate is error")
|
||
}
|
||
comma := strings.Split(colon[0], ",")
|
||
for _, m := range comma {
|
||
if m == "*" || inSlice(strings.ToLower(m), HTTPMETHOD) {
|
||
if val := reflectVal.MethodByName(colon[1]); val.IsValid() {
|
||
methods[strings.ToLower(m)] = colon[1]
|
||
} else {
|
||
panic(colon[1] + " method don't exist in the controller " + t.Name())
|
||
}
|
||
} else {
|
||
panic(v + " is an error method mapping,Don't exist method named " + m)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if j == 0 {
|
||
//now create the Route
|
||
route := &controllerInfo{}
|
||
route.pattern = pattern
|
||
route.controllerType = t
|
||
route.methods = methods
|
||
if len(methods) > 0 {
|
||
route.hasMethod = true
|
||
}
|
||
p.fixrouters = append(p.fixrouters, route)
|
||
} else { // add regexp routers
|
||
//recreate the url pattern, with parameters replaced
|
||
//by regular expressions. then compile the regex
|
||
pattern = strings.Join(parts, "/")
|
||
regex, regexErr := regexp.Compile(pattern)
|
||
if regexErr != nil {
|
||
//TODO add error handling here to avoid panic
|
||
panic(regexErr)
|
||
return
|
||
}
|
||
|
||
//now create the Route
|
||
|
||
route := &controllerInfo{}
|
||
route.regex = regex
|
||
route.params = params
|
||
route.pattern = pattern
|
||
route.methods = methods
|
||
if len(methods) > 0 {
|
||
route.hasMethod = true
|
||
}
|
||
route.controllerType = t
|
||
p.routers = append(p.routers, route)
|
||
}
|
||
}
|
||
|
||
func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
|
||
p.enableAuto = true
|
||
reflectVal := reflect.ValueOf(c)
|
||
rt := reflectVal.Type()
|
||
ct := reflect.Indirect(reflectVal).Type()
|
||
firstParam := strings.ToLower(strings.TrimSuffix(ct.Name(), "Controller"))
|
||
if _, ok := p.autoRouter[firstParam]; ok {
|
||
return
|
||
} else {
|
||
p.autoRouter[firstParam] = make(map[string]reflect.Type)
|
||
}
|
||
for i := 0; i < rt.NumMethod(); i++ {
|
||
p.autoRouter[firstParam][rt.Method(i).Name] = ct
|
||
}
|
||
}
|
||
|
||
// Filter adds the middleware filter.
|
||
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) {
|
||
p.enableFilter = true
|
||
mr := new(FilterRouter)
|
||
mr.filterFunc = filter
|
||
|
||
parts := strings.Split(pattern, "/")
|
||
j := 0
|
||
for i, part := range parts {
|
||
if strings.HasPrefix(part, ":") {
|
||
expr := "(.+)"
|
||
//a user may choose to override the defult expression
|
||
// similar to expressjs: ‘/user/:id([0-9]+)’
|
||
if index := strings.Index(part, "("); index != -1 {
|
||
expr = part[index:]
|
||
part = part[:index]
|
||
//match /user/:id:int ([0-9]+)
|
||
//match /post/:username:string ([\w]+)
|
||
} else if lindex := strings.LastIndex(part, ":"); lindex != 0 {
|
||
switch part[lindex:] {
|
||
case ":int":
|
||
expr = "([0-9]+)"
|
||
part = part[:lindex]
|
||
case ":string":
|
||
expr = `([\w]+)`
|
||
part = part[:lindex]
|
||
}
|
||
}
|
||
parts[i] = expr
|
||
j++
|
||
}
|
||
}
|
||
if j != 0 {
|
||
pattern = strings.Join(parts, "/")
|
||
regex, regexErr := regexp.Compile(pattern)
|
||
if regexErr != nil {
|
||
//TODO add error handling here to avoid panic
|
||
panic(regexErr)
|
||
return
|
||
}
|
||
mr.regex = regex
|
||
mr.hasregex = true
|
||
}
|
||
mr.pattern = pattern
|
||
p.filters[action] = append(p.filters[action], mr)
|
||
}
|
||
|
||
// AutoRoute
|
||
func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
errstr := fmt.Sprint(err)
|
||
if handler, ok := middleware.ErrorMaps[errstr]; ok && ErrorsShow {
|
||
handler(rw, r)
|
||
} else {
|
||
if !RecoverPanic {
|
||
// go back to panic
|
||
panic(err)
|
||
} else {
|
||
var stack string
|
||
Critical("Handler crashed with error", err)
|
||
for i := 1; ; i++ {
|
||
_, file, line, ok := runtime.Caller(i)
|
||
if !ok {
|
||
break
|
||
}
|
||
Critical(file, line)
|
||
if RunMode == "dev" {
|
||
stack = stack + fmt.Sprintln(file, line)
|
||
}
|
||
}
|
||
if RunMode == "dev" {
|
||
middleware.ShowErr(err, rw, r, stack)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
|
||
w := &responseWriter{writer: rw}
|
||
w.Header().Set("Server", "beegoServer")
|
||
context := &beecontext.Context{
|
||
ResponseWriter: w,
|
||
Request: r,
|
||
Input: beecontext.NewInput(r),
|
||
Output: beecontext.NewOutput(w),
|
||
}
|
||
context.Output.Context = context
|
||
|
||
if context.Input.IsWebsocket() {
|
||
context.ResponseWriter = rw
|
||
context.Output = beecontext.NewOutput(rw)
|
||
}
|
||
|
||
var runrouter *controllerInfo
|
||
var findrouter bool
|
||
|
||
params := make(map[string]string)
|
||
|
||
context.Input.Param = params
|
||
if p.enableFilter {
|
||
if l, ok := p.filters["BeforRouter"]; ok {
|
||
for _, filterR := range l {
|
||
if filterR.ValidRouter(r.URL.Path) {
|
||
filterR.filterFunc(context)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
//static file server
|
||
for prefix, staticDir := range StaticDir {
|
||
if r.URL.Path == "/favicon.ico" {
|
||
file := staticDir + r.URL.Path
|
||
http.ServeFile(w, r, file)
|
||
w.started = true
|
||
return
|
||
}
|
||
if strings.HasPrefix(r.URL.Path, prefix) {
|
||
file := staticDir + r.URL.Path[len(prefix):]
|
||
finfo, err := os.Stat(file)
|
||
if err != nil {
|
||
return
|
||
}
|
||
//if the request is dir and DirectoryIndex is false then
|
||
if finfo.IsDir() && !DirectoryIndex {
|
||
middleware.Exception("403", rw, r, "403 Forbidden")
|
||
return
|
||
}
|
||
http.ServeFile(w, r, file)
|
||
w.started = true
|
||
return
|
||
}
|
||
}
|
||
|
||
if p.enableFilter {
|
||
if l, ok := p.filters["AfterStatic"]; ok {
|
||
for _, filterR := range l {
|
||
if filterR.ValidRouter(r.URL.Path) {
|
||
filterR.filterFunc(context)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
requestPath := r.URL.Path
|
||
|
||
if CopyRequestBody {
|
||
context.Input.Body()
|
||
}
|
||
|
||
//first find path from the fixrouters to Improve Performance
|
||
for _, route := range p.fixrouters {
|
||
n := len(requestPath)
|
||
if requestPath == route.pattern {
|
||
runrouter = route
|
||
findrouter = true
|
||
break
|
||
}
|
||
// pattern /admin url /admin 200 /admin/ 404
|
||
// pattern /admin/ url /admin 301 /admin/ 200
|
||
if requestPath[n-1] != '/' && len(route.pattern) == n+1 &&
|
||
route.pattern[n] == '/' && route.pattern[:n-1] == requestPath {
|
||
http.Redirect(w, r, requestPath+"/", 301)
|
||
return
|
||
}
|
||
}
|
||
|
||
//find regex's router
|
||
if !findrouter {
|
||
//find a matching Route
|
||
for _, route := range p.routers {
|
||
|
||
//check if Route pattern matches url
|
||
if !route.regex.MatchString(requestPath) {
|
||
continue
|
||
}
|
||
|
||
//get submatches (params)
|
||
matches := route.regex.FindStringSubmatch(requestPath)
|
||
|
||
//double check that the Route matches the URL pattern.
|
||
if len(matches[0]) != len(requestPath) {
|
||
continue
|
||
}
|
||
|
||
if len(route.params) > 0 {
|
||
//add url parameters to the query param map
|
||
values := r.URL.Query()
|
||
for i, match := range matches[1:] {
|
||
values.Add(route.params[i], match)
|
||
params[route.params[i]] = match
|
||
}
|
||
//reassemble query params and add to RawQuery
|
||
r.URL.RawQuery = url.Values(values).Encode()
|
||
}
|
||
runrouter = route
|
||
findrouter = true
|
||
break
|
||
}
|
||
}
|
||
|
||
if runrouter != nil {
|
||
if r.Method == "POST" {
|
||
r.ParseMultipartForm(MaxMemory)
|
||
}
|
||
//execute middleware filters
|
||
if p.enableFilter {
|
||
if l, ok := p.filters["BeforExec"]; ok {
|
||
for _, filterR := range l {
|
||
if filterR.ValidRouter(r.URL.Path) {
|
||
filterR.filterFunc(context)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
//Invoke the request handler
|
||
vc := reflect.New(runrouter.controllerType)
|
||
|
||
//call the controller init function
|
||
init := vc.MethodByName("Init")
|
||
in := make([]reflect.Value, 2)
|
||
in[0] = reflect.ValueOf(context)
|
||
in[1] = reflect.ValueOf(runrouter.controllerType.Name())
|
||
init.Call(in)
|
||
//call prepare function
|
||
in = make([]reflect.Value, 0)
|
||
method := vc.MethodByName("Prepare")
|
||
method.Call(in)
|
||
|
||
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
|
||
if EnableXSRF {
|
||
method = vc.MethodByName("XsrfToken")
|
||
method.Call(in)
|
||
if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" ||
|
||
(r.Method == "POST" && (r.Form.Get("_method") == "delete" || r.Form.Get("_method") == "put")) {
|
||
method = vc.MethodByName("CheckXsrfCookie")
|
||
method.Call(in)
|
||
}
|
||
}
|
||
|
||
//if response has written,yes don't run next
|
||
if !w.started {
|
||
if r.Method == "GET" {
|
||
if runrouter.hasMethod {
|
||
if m, ok := runrouter.methods["get"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else if m, ok = runrouter.methods["*"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else {
|
||
method = vc.MethodByName("Get")
|
||
}
|
||
} else {
|
||
method = vc.MethodByName("Get")
|
||
}
|
||
method.Call(in)
|
||
} else if r.Method == "HEAD" {
|
||
if runrouter.hasMethod {
|
||
if m, ok := runrouter.methods["head"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else if m, ok = runrouter.methods["*"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else {
|
||
method = vc.MethodByName("Head")
|
||
}
|
||
} else {
|
||
method = vc.MethodByName("Head")
|
||
}
|
||
|
||
method.Call(in)
|
||
} else if r.Method == "DELETE" || (r.Method == "POST" && r.Form.Get("_method") == "delete") {
|
||
if runrouter.hasMethod {
|
||
if m, ok := runrouter.methods["delete"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else if m, ok = runrouter.methods["*"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else {
|
||
method = vc.MethodByName("Delete")
|
||
}
|
||
} else {
|
||
method = vc.MethodByName("Delete")
|
||
}
|
||
method.Call(in)
|
||
} else if r.Method == "PUT" || (r.Method == "POST" && r.Form.Get("_method") == "put") {
|
||
if runrouter.hasMethod {
|
||
if m, ok := runrouter.methods["put"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else if m, ok = runrouter.methods["*"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else {
|
||
method = vc.MethodByName("Put")
|
||
}
|
||
} else {
|
||
method = vc.MethodByName("Put")
|
||
}
|
||
method.Call(in)
|
||
} else if r.Method == "POST" {
|
||
if runrouter.hasMethod {
|
||
if m, ok := runrouter.methods["post"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else if m, ok = runrouter.methods["*"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else {
|
||
method = vc.MethodByName("Post")
|
||
}
|
||
} else {
|
||
method = vc.MethodByName("Post")
|
||
}
|
||
method.Call(in)
|
||
} else if r.Method == "PATCH" {
|
||
if runrouter.hasMethod {
|
||
if m, ok := runrouter.methods["patch"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else if m, ok = runrouter.methods["*"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else {
|
||
method = vc.MethodByName("Patch")
|
||
}
|
||
} else {
|
||
method = vc.MethodByName("Patch")
|
||
}
|
||
method.Call(in)
|
||
} else if r.Method == "OPTIONS" {
|
||
if runrouter.hasMethod {
|
||
if m, ok := runrouter.methods["options"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else if m, ok = runrouter.methods["*"]; ok {
|
||
method = vc.MethodByName(m)
|
||
} else {
|
||
method = vc.MethodByName("Options")
|
||
}
|
||
} else {
|
||
method = vc.MethodByName("Options")
|
||
}
|
||
method.Call(in)
|
||
}
|
||
gotofunc := vc.Elem().FieldByName("gotofunc").String()
|
||
if gotofunc != "" {
|
||
method = vc.MethodByName(gotofunc)
|
||
if method.IsValid() {
|
||
method.Call(in)
|
||
} else {
|
||
panic("gotofunc is exists:" + gotofunc)
|
||
}
|
||
}
|
||
if !w.started && !context.Input.IsWebsocket() {
|
||
if AutoRender {
|
||
method = vc.MethodByName("Render")
|
||
method.Call(in)
|
||
}
|
||
}
|
||
}
|
||
|
||
method = vc.MethodByName("Finish")
|
||
method.Call(in)
|
||
//execute middleware filters
|
||
if p.enableFilter {
|
||
if l, ok := p.filters["AfterExec"]; ok {
|
||
for _, filterR := range l {
|
||
if filterR.ValidRouter(r.URL.Path) {
|
||
filterR.filterFunc(context)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
method = vc.MethodByName("Destructor")
|
||
method.Call(in)
|
||
}
|
||
|
||
//start autorouter
|
||
|
||
if p.enableAuto {
|
||
if !findrouter {
|
||
for cName, methodmap := range p.autoRouter {
|
||
|
||
if strings.ToLower(requestPath) == "/"+cName {
|
||
http.Redirect(w, r, requestPath+"/", 301)
|
||
return
|
||
}
|
||
|
||
if strings.ToLower(requestPath) == "/"+cName+"/" {
|
||
requestPath = requestPath + "index"
|
||
}
|
||
if strings.HasPrefix(strings.ToLower(requestPath), "/"+cName+"/") {
|
||
for mName, controllerType := range methodmap {
|
||
if strings.HasPrefix(strings.ToLower(requestPath), "/"+cName+"/"+strings.ToLower(mName)) {
|
||
if r.Method == "POST" {
|
||
r.ParseMultipartForm(MaxMemory)
|
||
}
|
||
//execute middleware filters
|
||
if p.enableFilter {
|
||
if l, ok := p.filters["BeforExec"]; ok {
|
||
for _, filterR := range l {
|
||
if filterR.ValidRouter(r.URL.Path) {
|
||
filterR.filterFunc(context)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
//parse params
|
||
otherurl := requestPath[len("/"+cName+"/"+strings.ToLower(mName)):]
|
||
if len(otherurl) > 1 {
|
||
plist := strings.Split(otherurl, "/")
|
||
for k, v := range plist[1:] {
|
||
params[strconv.Itoa(k)] = v
|
||
}
|
||
}
|
||
//Invoke the request handler
|
||
vc := reflect.New(controllerType)
|
||
|
||
//call the controller init function
|
||
init := vc.MethodByName("Init")
|
||
in := make([]reflect.Value, 2)
|
||
in[0] = reflect.ValueOf(context)
|
||
in[1] = reflect.ValueOf(controllerType.Name())
|
||
init.Call(in)
|
||
//call prepare function
|
||
in = make([]reflect.Value, 0)
|
||
method := vc.MethodByName("Prepare")
|
||
method.Call(in)
|
||
method = vc.MethodByName(mName)
|
||
method.Call(in)
|
||
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
|
||
if EnableXSRF {
|
||
method = vc.MethodByName("XsrfToken")
|
||
method.Call(in)
|
||
if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" ||
|
||
(r.Method == "POST" && (r.Form.Get("_method") == "delete" || r.Form.Get("_method") == "put")) {
|
||
method = vc.MethodByName("CheckXsrfCookie")
|
||
method.Call(in)
|
||
}
|
||
}
|
||
if !w.started && !context.Input.IsWebsocket() {
|
||
if AutoRender {
|
||
method = vc.MethodByName("Render")
|
||
method.Call(in)
|
||
}
|
||
}
|
||
method = vc.MethodByName("Finish")
|
||
method.Call(in)
|
||
//execute middleware filters
|
||
if p.enableFilter {
|
||
if l, ok := p.filters["AfterExec"]; ok {
|
||
for _, filterR := range l {
|
||
if filterR.ValidRouter(r.URL.Path) {
|
||
filterR.filterFunc(context)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
method = vc.MethodByName("Destructor")
|
||
method.Call(in)
|
||
// set find
|
||
findrouter = true
|
||
goto Last
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
Last:
|
||
//if no matches to url, throw a not found exception
|
||
if !findrouter {
|
||
middleware.Exception("404", rw, r, "")
|
||
}
|
||
}
|
||
|
||
// responseWriter wraps an http.ResponseWriter and remembers whether anything
// has been written, so the router can skip later handlers once a response
// has started.
type responseWriter struct {
	writer  http.ResponseWriter
	started bool // true once Write or WriteHeader has been called
	status  int  // last code passed to WriteHeader; 0 if never called
}

// Header returns the header map of the wrapped writer, which will be sent by
// WriteHeader.
func (w *responseWriter) Header() http.Header {
	h := w.writer.Header()
	return h
}

// Write marks the response as started and forwards p to the wrapped writer,
// returning its result unchanged.
func (w *responseWriter) Write(p []byte) (int, error) {
	w.started = true
	n, err := w.writer.Write(p)
	return n, err
}

// WriteHeader records code, marks the response as started and forwards the
// status code to the wrapped writer.
func (w *responseWriter) WriteHeader(code int) {
	w.started, w.status = true, code
	w.writer.WriteHeader(code)
}
|