This commit is contained in:
astaxie 2013-11-26 16:47:50 +08:00
parent fbd1c3f8b0
commit 54fb49ed95
4 changed files with 159 additions and 63 deletions

View File

@ -139,6 +139,7 @@ func (input *BeegoInput) Param(key string) string {
}
func (input *BeegoInput) Query(key string) string {
input.req.ParseForm()
return input.req.Form.Get(key)
}

133
filter.go
View File

@ -2,36 +2,143 @@ package beego
import (
"regexp"
"strings"
)
type FilterRouter struct {
pattern string
regex *regexp.Regexp
filterFunc FilterFunc
hasregex bool
params map[int]string
pattern string
regex *regexp.Regexp
filterFunc FilterFunc
hasregex bool
params map[int]string
parseParams map[string]string
}
func (mr *FilterRouter) ValidRouter(router string) bool {
func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
if mr.pattern == "" {
return true
return true, nil
}
if mr.pattern == "*" {
return true
return true, nil
}
if router == mr.pattern {
return true
return true, nil
}
if mr.hasregex {
if mr.regex.MatchString(router) {
return true
if !mr.regex.MatchString(router) {
return false, nil
}
matches := mr.regex.FindStringSubmatch(router)
if len(matches) > 0 {
if len(matches[0]) == len(router) {
return true
params := make(map[string]string)
for i, match := range matches[1:] {
params[mr.params[i]] = match
}
return true, params
}
}
}
return false
return false, nil
}
func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
mr := new(FilterRouter)
mr.params = make(map[int]string)
mr.filterFunc = filter
parts := strings.Split(pattern, "/")
j := 0
for i, part := range parts {
if strings.HasPrefix(part, ":") {
expr := "(.+)"
//a user may choose to override the default expression
// similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 {
expr = part[index:]
part = part[:index]
//match /user/:id:int ([0-9]+)
//match /post/:username:string ([\w]+)
} else if lindex := strings.LastIndex(part, ":"); lindex != 0 {
switch part[lindex:] {
case ":int":
expr = "([0-9]+)"
part = part[:lindex]
case ":string":
expr = `([\w]+)`
part = part[:lindex]
}
}
mr.params[j] = part
parts[i] = expr
j++
}
if strings.HasPrefix(part, "*") {
expr := "(.+)"
if part == "*.*" {
mr.params[j] = ":path"
parts[i] = "([^.]+).([^.]+)"
j++
mr.params[j] = ":ext"
j++
} else {
mr.params[j] = ":splat"
parts[i] = expr
j++
}
}
//url like someprefix:id(xxx).html
if strings.Contains(part, ":") && strings.Contains(part, "(") && strings.Contains(part, ")") {
var out []rune
var start bool
var startexp bool
var param []rune
var expt []rune
for _, v := range part {
if start {
if v != '(' {
param = append(param, v)
continue
}
}
if startexp {
if v != ')' {
expt = append(expt, v)
continue
}
}
if v == ':' {
param = make([]rune, 0)
param = append(param, ':')
start = true
} else if v == '(' {
startexp = true
start = false
mr.params[j] = string(param)
j++
expt = make([]rune, 0)
expt = append(expt, '(')
} else if v == ')' {
startexp = false
expt = append(expt, ')')
out = append(out, expt...)
} else {
out = append(out, v)
}
}
parts[i] = string(out)
}
}
if j != 0 {
pattern = strings.Join(parts, "/")
regex, regexErr := regexp.Compile(pattern)
if regexErr != nil {
//TODO add error handling here to avoid panic
panic(regexErr)
}
mr.regex = regex
mr.hasregex = true
}
mr.pattern = pattern
return mr
}

24
fiter_test.go Normal file
View File

@ -0,0 +1,24 @@
package beego
import (
"github.com/astaxie/beego/context"
"net/http"
"net/http/httptest"
"testing"
)
var FilterUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"]))
}
func TestFilter(t *testing.T) {
r, _ := http.NewRequest("GET", "/person/asta/Xie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddFilter("/person/:last/:first", "AfterStatic", FilterUser)
handler.Add("/person/:last/:first", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am astaXie" {
t.Errorf("user define func can't run")
}
}

View File

@ -222,49 +222,6 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
}
}
func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
mr := new(FilterRouter)
mr.filterFunc = filter
parts := strings.Split(pattern, "/")
j := 0
for i, part := range parts {
if strings.HasPrefix(part, ":") {
expr := "(.+)"
//a user may choose to override the default expression
// similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 {
expr = part[index:]
part = part[:index]
//match /user/:id:int ([0-9]+)
//match /post/:username:string ([\w]+)
} else if lindex := strings.LastIndex(part, ":"); lindex != 0 {
switch part[lindex:] {
case ":int":
expr = "([0-9]+)"
part = part[:lindex]
case ":string":
expr = `([\w]+)`
part = part[:lindex]
}
}
parts[i] = expr
j++
}
}
if j != 0 {
pattern = strings.Join(parts, "/")
regex, regexErr := regexp.Compile(pattern)
if regexErr != nil {
//TODO add error handling here to avoid panic
panic(regexErr)
}
mr.regex = regex
mr.hasregex = true
}
mr.pattern = pattern
return mr
}
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) {
mr := buildFilter(pattern, filter)
switch action {
@ -469,7 +426,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if p.enableFilter {
if l, ok := p.filters[BeforeRouter]; ok {
for _, filterR := range l {
if filterR.ValidRouter(r.URL.Path) {
if ok, p := filterR.ValidRouter(r.URL.Path); ok {
context.Input.Params = p
filterR.filterFunc(context)
if w.started {
goto Admin
@ -516,7 +474,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if p.enableFilter {
if l, ok := p.filters[AfterStatic]; ok {
for _, filterR := range l {
if filterR.ValidRouter(r.URL.Path) {
if ok, p := filterR.ValidRouter(r.URL.Path); ok {
context.Input.Params = p
filterR.filterFunc(context)
if w.started {
goto Admin
@ -591,7 +550,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if p.enableFilter {
if l, ok := p.filters[BeforeExec]; ok {
for _, filterR := range l {
if filterR.ValidRouter(r.URL.Path) {
if ok, p := filterR.ValidRouter(r.URL.Path); ok {
context.Input.Params = p
filterR.filterFunc(context)
if w.started {
goto Admin
@ -746,7 +706,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if p.enableFilter {
if l, ok := p.filters[AfterExec]; ok {
for _, filterR := range l {
if filterR.ValidRouter(r.URL.Path) {
if ok, p := filterR.ValidRouter(r.URL.Path); ok {
context.Input.Params = p
filterR.filterFunc(context)
if w.started {
goto Admin
@ -795,7 +756,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if p.enableFilter {
if l, ok := p.filters[BeforeExec]; ok {
for _, filterR := range l {
if filterR.ValidRouter(r.URL.Path) {
if ok, p := filterR.ValidRouter(r.URL.Path); ok {
context.Input.Params = p
filterR.filterFunc(context)
if w.started {
goto Admin
@ -850,7 +812,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if p.enableFilter {
if l, ok := p.filters[AfterExec]; ok {
for _, filterR := range l {
if filterR.ValidRouter(r.URL.Path) {
if ok, p := filterR.ValidRouter(r.URL.Path); ok {
context.Input.Params = p
filterR.filterFunc(context)
if w.started {
goto Admin
@ -878,7 +841,8 @@ Admin:
if p.enableFilter {
if l, ok := p.filters[FinishRouter]; ok {
for _, filterR := range l {
if filterR.ValidRouter(r.URL.Path) {
if ok, p := filterR.ValidRouter(r.URL.Path); ok {
context.Input.Params = p
filterR.filterFunc(context)
if w.started {
break