1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-29 13:51:29 +00:00

Merge pull request #1863 from JessonChan/xsrf_fix

Xsrf fix
This commit is contained in:
astaxie 2016-04-12 14:26:20 +08:00
commit 9c400778d3
4 changed files with 91 additions and 14 deletions

View File

@ -65,6 +65,7 @@ func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) {
ctx.ResponseWriter.reset(rw) ctx.ResponseWriter.reset(rw)
ctx.Input.Reset(ctx) ctx.Input.Reset(ctx)
ctx.Output.Reset(ctx) ctx.Output.Reset(ctx)
ctx._xsrfToken = ""
} }
// Redirect does redirection to localurl with http header status code. // Redirect does redirection to localurl with http header status code.

47
context/context_test.go Normal file
View File

@ -0,0 +1,47 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package context
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestXsrfReset_01(t *testing.T) {
r := &http.Request{}
c := NewContext()
c.Request = r
c.ResponseWriter = &Response{}
c.ResponseWriter.reset(httptest.NewRecorder())
c.Output.Reset(c)
c.Input.Reset(c)
c.XSRFToken("key", 16)
if c._xsrfToken == "" {
t.FailNow()
}
token := c._xsrfToken
c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r)
if c._xsrfToken != "" {
t.FailNow()
}
c.XSRFToken("key", 16)
if c._xsrfToken == "" {
t.FailNow()
}
if token == c._xsrfToken {
t.FailNow()
}
}

View File

@ -20,29 +20,25 @@ import (
"time" "time"
) )
var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`)
// RandomCreateBytes generate random []byte by specify chars. // RandomCreateBytes generate random []byte by specify chars.
func RandomCreateBytes(n int, alphabets ...byte) []byte { func RandomCreateBytes(n int, alphabets ...byte) []byte {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" if len(alphabets) == 0 {
alphabets = alphaNum
}
var bytes = make([]byte, n) var bytes = make([]byte, n)
var randby bool var randBy bool
if num, err := rand.Read(bytes); num != n || err != nil { if num, err := rand.Read(bytes); num != n || err != nil {
r.Seed(time.Now().UnixNano()) r.Seed(time.Now().UnixNano())
randby = true randBy = true
} }
for i, b := range bytes { for i, b := range bytes {
if len(alphabets) == 0 { if randBy {
if randby {
bytes[i] = alphanum[r.Intn(len(alphanum))]
} else {
bytes[i] = alphanum[b%byte(len(alphanum))]
}
} else {
if randby {
bytes[i] = alphabets[r.Intn(len(alphabets))] bytes[i] = alphabets[r.Intn(len(alphabets))]
} else { } else {
bytes[i] = alphabets[b%byte(len(alphabets))] bytes[i] = alphabets[b%byte(len(alphabets))]
} }
} }
}
return bytes return bytes
} }

33
utils/rand_test.go Normal file
View File

@ -0,0 +1,33 @@
// Copyright 2016 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import "testing"
func TestRand_01(t *testing.T) {
bs0 := RandomCreateBytes(16)
bs1 := RandomCreateBytes(16)
t.Log(string(bs0), string(bs1))
if string(bs0) == string(bs1) {
t.FailNow()
}
bs0 = RandomCreateBytes(4, []byte(`a`)...)
if string(bs0) != "aaaa" {
t.FailNow()
}
}