diff --git a/context/context.go b/context/context.go index fee5e1c5..2735a5b3 100644 --- a/context/context.go +++ b/context/context.go @@ -65,6 +65,7 @@ func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { ctx.ResponseWriter.reset(rw) ctx.Input.Reset(ctx) ctx.Output.Reset(ctx) + ctx._xsrfToken = "" } // Redirect does redirection to localurl with http header status code. diff --git a/context/context_test.go b/context/context_test.go new file mode 100644 index 00000000..7c0535e0 --- /dev/null +++ b/context/context_test.go @@ -0,0 +1,47 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestXsrfReset_01(t *testing.T) { + r := &http.Request{} + c := NewContext() + c.Request = r + c.ResponseWriter = &Response{} + c.ResponseWriter.reset(httptest.NewRecorder()) + c.Output.Reset(c) + c.Input.Reset(c) + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + token := c._xsrfToken + c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r) + if c._xsrfToken != "" { + t.FailNow() + } + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + if token == c._xsrfToken { + t.FailNow() + } +} diff --git a/utils/rand.go b/utils/rand.go index 74bb4121..344d1cd5 100644 --- a/utils/rand.go +++ b/utils/rand.go @@ -20,28 +20,24 @@ import ( "time" ) +var alphaNum = []byte(`0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz`) + // RandomCreateBytes generate random []byte by specify chars. func RandomCreateBytes(n int, alphabets ...byte) []byte { - const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + if len(alphabets) == 0 { + alphabets = alphaNum + } var bytes = make([]byte, n) - var randby bool + var randBy bool if num, err := rand.Read(bytes); num != n || err != nil { r.Seed(time.Now().UnixNano()) - randby = true + randBy = true } for i, b := range bytes { - if len(alphabets) == 0 { - if randby { - bytes[i] = alphanum[r.Intn(len(alphanum))] - } else { - bytes[i] = alphanum[b%byte(len(alphanum))] - } + if randBy { + bytes[i] = alphabets[r.Intn(len(alphabets))] } else { - if randby { - bytes[i] = alphabets[r.Intn(len(alphabets))] - } else { - bytes[i] = alphabets[b%byte(len(alphabets))] - } + bytes[i] = alphabets[b%byte(len(alphabets))] } } return bytes diff --git a/utils/rand_test.go b/utils/rand_test.go new file mode 100644 index 00000000..6c238b5e --- /dev/null +++ b/utils/rand_test.go @@ -0,0 +1,33 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "testing" + +func TestRand_01(t *testing.T) { + bs0 := RandomCreateBytes(16) + bs1 := RandomCreateBytes(16) + + t.Log(string(bs0), string(bs1)) + if string(bs0) == string(bs1) { + t.FailNow() + } + + bs0 = RandomCreateBytes(4, []byte(`a`)...) + + if string(bs0) != "aaaa" { + t.FailNow() + } +}