diff --git a/context/input.go b/context/input.go index f109f4b3..809f2366 100644 --- a/context/input.go +++ b/context/input.go @@ -139,6 +139,7 @@ func (input *BeegoInput) Param(key string) string { } func (input *BeegoInput) Query(key string) string { + input.req.ParseForm() return input.req.Form.Get(key) } diff --git a/filter.go b/filter.go index d4a6b60b..05e5fd49 100644 --- a/filter.go +++ b/filter.go @@ -2,36 +2,143 @@ package beego import ( "regexp" + "strings" ) type FilterRouter struct { - pattern string - regex *regexp.Regexp - filterFunc FilterFunc - hasregex bool - params map[int]string + pattern string + regex *regexp.Regexp + filterFunc FilterFunc + hasregex bool + params map[int]string + parseParams map[string]string } -func (mr *FilterRouter) ValidRouter(router string) bool { +func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) { if mr.pattern == "" { - return true + return true, nil } if mr.pattern == "*" { - return true + return true, nil } if router == mr.pattern { - return true + return true, nil } if mr.hasregex { - if mr.regex.MatchString(router) { - return true + if !mr.regex.MatchString(router) { + return false, nil } matches := mr.regex.FindStringSubmatch(router) if len(matches) > 0 { if len(matches[0]) == len(router) { - return true + params := make(map[string]string) + for i, match := range matches[1:] { + params[mr.params[i]] = match + } + return true, params } } } - return false + return false, nil +} + +func buildFilter(pattern string, filter FilterFunc) *FilterRouter { + mr := new(FilterRouter) + mr.params = make(map[int]string) + mr.filterFunc = filter + parts := strings.Split(pattern, "/") + j := 0 + for i, part := range parts { + if strings.HasPrefix(part, ":") { + expr := "(.+)" + //a user may choose to override the default expression + // similar to expressjs: ‘/user/:id([0-9]+)’ + if index := strings.Index(part, "("); index != -1 { + expr = part[index:] + part = part[:index] + //match /user/:id:int ([0-9]+) + //match /post/:username:string ([\w]+) + } else if lindex := strings.LastIndex(part, ":"); lindex != 0 { + switch part[lindex:] { + case ":int": + expr = "([0-9]+)" + part = part[:lindex] + case ":string": + expr = `([\w]+)` + part = part[:lindex] + } + } + mr.params[j] = part + parts[i] = expr + j++ + } + if strings.HasPrefix(part, "*") { + expr := "(.+)" + if part == "*.*" { + mr.params[j] = ":path" + parts[i] = "([^.]+).([^.]+)" + j++ + mr.params[j] = ":ext" + j++ + } else { + mr.params[j] = ":splat" + parts[i] = expr + j++ + } + } + //url like someprefix:id(xxx).html + if strings.Contains(part, ":") && strings.Contains(part, "(") && strings.Contains(part, ")") { + var out []rune + var start bool + var startexp bool + var param []rune + var expt []rune + for _, v := range part { + if start { + if v != '(' { + param = append(param, v) + continue + } + } + if startexp { + if v != ')' { + expt = append(expt, v) + continue + } + } + if v == ':' { + param = make([]rune, 0) + param = append(param, ':') + start = true + } else if v == '(' { + startexp = true + start = false + mr.params[j] = string(param) + j++ + expt = make([]rune, 0) + expt = append(expt, '(') + } else if v == ')' { + startexp = false + expt = append(expt, ')') + out = append(out, expt...) + } else { + out = append(out, v) + } + } + parts[i] = string(out) + } + } + + if j != 0 { + pattern = strings.Join(parts, "/") + regex, regexErr := regexp.Compile(pattern) + if regexErr != nil { + //TODO add error handling here to avoid panic + panic(regexErr) + } + mr.regex = regex + mr.hasregex = true + } + mr.pattern = pattern + return mr } diff --git a/fiter_test.go b/fiter_test.go new file mode 100644 index 00000000..4a1b8639 --- /dev/null +++ b/fiter_test.go @@ -0,0 +1,24 @@ +package beego + +import ( + "github.com/astaxie/beego/context" + "net/http" + "net/http/httptest" + "testing" +) + +var FilterUser = func(ctx *context.Context) { + ctx.Output.Body([]byte("i am " + ctx.Input.Params[":last"] + ctx.Input.Params[":first"])) +} + +func TestFilter(t *testing.T) { + r, _ := http.NewRequest("GET", "/person/asta/Xie", nil) + w := httptest.NewRecorder() + handler := NewControllerRegistor() + handler.AddFilter("/person/:last/:first", "AfterStatic", FilterUser) + handler.Add("/person/:last/:first", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am astaXie" { + t.Errorf("user define func can't run") + } +} diff --git a/router.go b/router.go index 7300ecb9..395e7eb7 100644 --- a/router.go +++ b/router.go @@ -222,49 +222,6 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) { } } -func buildFilter(pattern string, filter FilterFunc) *FilterRouter { - mr := new(FilterRouter) - mr.filterFunc = filter - parts := strings.Split(pattern, "/") - j := 0 - for i, part := range parts { - if strings.HasPrefix(part, ":") { - expr := "(.+)" - //a user may choose to override the default expression - // similar to expressjs: ‘/user/:id([0-9]+)’ - if index := strings.Index(part, "("); index != -1 { - expr = part[index:] - part = part[:index] - //match /user/:id:int ([0-9]+) - //match /post/:username:string ([\w]+) - } else if lindex := strings.LastIndex(part, ":"); lindex != 0 { - switch part[lindex:] { - case ":int": - expr = "([0-9]+)" - part = part[:lindex] - case ":string": - expr = `([\w]+)` - part = part[:lindex] - } - } - parts[i] = expr - j++ - } - } - if j != 0 { - pattern = strings.Join(parts, "/") - regex, regexErr := regexp.Compile(pattern) - if regexErr != nil { - //TODO add error handling here to avoid panic - panic(regexErr) - } - mr.regex = regex - mr.hasregex = true - } - mr.pattern = pattern - return mr -} - func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) { mr := buildFilter(pattern, filter) switch action { @@ -469,7 +426,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) if p.enableFilter { if l, ok := p.filters[BeforeRouter]; ok { for _, filterR := range l { - if filterR.ValidRouter(r.URL.Path) { + if ok, p := filterR.ValidRouter(r.URL.Path); ok { + context.Input.Params = p filterR.filterFunc(context) if w.started { goto Admin @@ -516,7 +474,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) if p.enableFilter { if l, ok := p.filters[AfterStatic]; ok { for _, filterR := range l { - if filterR.ValidRouter(r.URL.Path) { + if ok, p := filterR.ValidRouter(r.URL.Path); ok { + context.Input.Params = p filterR.filterFunc(context) if w.started { goto Admin @@ -591,7 +550,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) if p.enableFilter { if l, ok := p.filters[BeforeExec]; ok { for _, filterR := range l { - if filterR.ValidRouter(r.URL.Path) { + if ok, p := filterR.ValidRouter(r.URL.Path); ok { + context.Input.Params = p filterR.filterFunc(context) if w.started { goto Admin @@ -746,7 +706,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) if p.enableFilter { if l, ok := p.filters[AfterExec]; ok { for _, filterR := range l { - if filterR.ValidRouter(r.URL.Path) { + if ok, p := filterR.ValidRouter(r.URL.Path); ok { + context.Input.Params = p filterR.filterFunc(context) if w.started { goto Admin @@ -795,7 +756,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) if p.enableFilter { if l, ok := p.filters[BeforeExec]; ok { for _, filterR := range l { - if filterR.ValidRouter(r.URL.Path) { + if ok, p := filterR.ValidRouter(r.URL.Path); ok { + context.Input.Params = p filterR.filterFunc(context) if w.started { goto Admin @@ -850,7 +812,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) if p.enableFilter { if l, ok := p.filters[AfterExec]; ok { for _, filterR := range l { - if filterR.ValidRouter(r.URL.Path) { + if ok, p := filterR.ValidRouter(r.URL.Path); ok { + context.Input.Params = p filterR.filterFunc(context) if w.started { goto Admin @@ -878,7 +841,8 @@ Admin: if p.enableFilter { if l, ok := p.filters[FinishRouter]; ok { for _, filterR := range l { - if filterR.ValidRouter(r.URL.Path) { + if ok, p := filterR.ValidRouter(r.URL.Path); ok { + context.Input.Params = p filterR.filterFunc(context) if w.started { break