From 3a5de83ec243e4dd509c488ea1edabb3ebc3c2a2 Mon Sep 17 00:00:00 2001 From: astaxie Date: Sun, 28 Sep 2014 22:10:43 +0800 Subject: [PATCH] beego: support router case sensitive --- config.go | 7 +++++++ router.go | 16 ++++++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index db234e16..9f952735 100644 --- a/config.go +++ b/config.go @@ -81,6 +81,7 @@ var ( FlashSeperator string // used to seperate flash key:value AppConfigProvider string // config provider EnableDocs bool // enable generate docs & server docs API Swagger + RouterCaseSensitive bool // router case sensitive default is true ) func init() { @@ -164,6 +165,8 @@ func init() { FlashName = "BEEGO_FLASH" FlashSeperator = "BEEGOFLASH" + RouterCaseSensitive = true + runtime.GOMAXPROCS(runtime.NumCPU()) // init BeeLogger @@ -375,6 +378,10 @@ func ParseConfig() (err error) { if enabledocs, err := GetConfig("bool", "EnableDocs"); err == nil { EnableDocs = enabledocs.(bool) } + + if casesensitive, err := GetConfig("bool", "RouterCaseSensitive"); err == nil { + RouterCaseSensitive = casesensitive.(bool) + } } return nil } diff --git a/router.go b/router.go index 18be04d3..b14cc5dd 100644 --- a/router.go +++ b/router.go @@ -163,6 +163,9 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM } func (p *ControllerRegistor) addToRouter(method, pattern string, r *controllerInfo) { + if !RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } if t, ok := p.routers[method]; ok { t.AddRouter(pattern, r) } else { @@ -381,6 +384,9 @@ func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter Filter mr.tree = NewTree() mr.pattern = pattern mr.filterFunc = filter + if !RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } mr.tree.AddRouter(pattern, true) return p.insertFilterRouter(pos, mr) } @@ -565,12 +571,18 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) context.Output.Context = context context.Output.EnableGzip = EnableGzip + var urlPath string + if !RouterCaseSensitive { + urlPath = strings.ToLower(r.URL.Path) + } else { + urlPath = r.URL.Path + } // defined filter function do_filter := func(pos int) (started bool) { if p.enableFilter { if l, ok := p.filters[pos]; ok { for _, filterR := range l { - if ok, p := filterR.ValidRouter(r.URL.Path); ok { + if ok, p := filterR.ValidRouter(urlPath); ok { context.Input.Params = p filterR.filterFunc(context) if w.started { @@ -628,7 +640,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) if !findrouter { if t, ok := p.routers[r.Method]; ok { - runObject, p := t.Match(r.URL.Path) + runObject, p := t.Match(urlPath) if r, ok := runObject.(*controllerInfo); ok { routerInfo = r findrouter = true