1
0
mirror of https://github.com/astaxie/beego.git synced 2024-12-23 09:00:49 +00:00

fix #620 simple the sessionID generate

This commit is contained in:
astaxie 2014-11-04 19:04:26 +08:00
parent c4d8e4a244
commit fc6b9ce009

View File

@ -28,19 +28,13 @@
package session package session
import ( import (
"crypto/hmac"
"crypto/md5"
"crypto/rand" "crypto/rand"
"crypto/sha1"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
"github.com/astaxie/beego/utils"
) )
// SessionStore contains all data for one session process with specific id. // SessionStore contains all data for one session process with specific id.
@ -86,11 +80,10 @@ type managerConfig struct {
Gclifetime int64 `json:"gclifetime"` Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"` Maxlifetime int64 `json:"maxLifetime"`
Secure bool `json:"secure"` Secure bool `json:"secure"`
SessionIDHashFunc string `json:"sessionIDHashFunc"`
SessionIDHashKey string `json:"sessionIDHashKey"`
CookieLifeTime int `json:"cookieLifeTime"` CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"` ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"` Domain string `json:"domain"`
SessionIdLength int64 `json:"sessionIdLength"`
} }
// Manager contains Provider and its configuration. // Manager contains Provider and its configuration.
@ -129,11 +122,9 @@ func NewManager(provideName, config string) (*Manager, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if cf.SessionIDHashFunc == "" {
cf.SessionIDHashFunc = "sha1" if cf.SessionIdLength == 0 {
} cf.SessionIdLength = 16
if cf.SessionIDHashKey == "" {
cf.SessionIDHashKey = string(generateRandomKey(16))
} }
return &Manager{ return &Manager{
@ -144,11 +135,14 @@ func NewManager(provideName, config string) (*Manager, error) {
// Start session. generate or read the session id from http request. // Start session. generate or read the session id from http request.
// if session id exists, return SessionStore with this id. // if session id exists, return SessionStore with this id.
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore, err error) {
cookie, err := r.Cookie(manager.config.CookieName) cookie, errs := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" { if errs != nil || cookie.Value == "" {
sid := manager.sessionId(r) sid, errs := manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid) if errs != nil {
return nil, errs
}
session, err = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.config.CookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
@ -163,12 +157,18 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
} }
r.AddCookie(cookie) r.AddCookie(cookie)
} else { } else {
sid, _ := url.QueryUnescape(cookie.Value) sid, errs := url.QueryUnescape(cookie.Value)
if errs != nil {
return nil, errs
}
if manager.provider.SessionExist(sid) { if manager.provider.SessionExist(sid) {
session, _ = manager.provider.SessionRead(sid) session, err = manager.provider.SessionRead(sid)
} else { } else {
sid = manager.sessionId(r) sid, err = manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid) if err != nil {
return nil, err
}
session, err = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.config.CookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
@ -219,7 +219,10 @@ func (manager *Manager) GC() {
// Regenerate a session id for this SessionStore who's id is saving in http request. // Regenerate a session id for this SessionStore who's id is saving in http request.
func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) { func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
sid := manager.sessionId(r) sid, err := manager.sessionId(r)
if err != nil {
return
}
cookie, err := r.Cookie(manager.config.CookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil && cookie.Value == "" { if err != nil && cookie.Value == "" {
//delete old cookie //delete old cookie
@ -251,36 +254,16 @@ func (manager *Manager) GetActiveSession() int {
return manager.provider.SessionAll() return manager.provider.SessionAll()
} }
// Set hash function for generating session id.
func (manager *Manager) SetHashFunc(hasfunc, hashkey string) {
manager.config.SessionIDHashFunc = hasfunc
manager.config.SessionIDHashKey = hashkey
}
// Set cookie with https. // Set cookie with https.
func (manager *Manager) SetSecure(secure bool) { func (manager *Manager) SetSecure(secure bool) {
manager.config.Secure = secure manager.config.Secure = secure
} }
// generate session id with rand string, unix nano time, remote addr by hash function. func (manager *Manager) sessionId(r *http.Request) (string, error) {
func (manager *Manager) sessionId(r *http.Request) (sid string) { b := make([]byte, manager.config.SessionIdLength)
bs := make([]byte, 32) n, err := rand.Read(b)
if n, err := io.ReadFull(rand.Reader, bs); n != 32 || err != nil { if n != len(b) || err != nil {
bs = utils.RandomCreateBytes(32) return "", fmt.Errorf("Could not successfully read from the system CSPRNG.")
} }
sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs) return hex.EncodeToString(b), nil
if manager.config.SessionIDHashFunc == "md5" {
h := md5.New()
h.Write([]byte(sig))
sid = hex.EncodeToString(h.Sum(nil))
} else if manager.config.SessionIDHashFunc == "sha1" {
h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
fmt.Fprintf(h, "%s", sig)
sid = hex.EncodeToString(h.Sum(nil))
} else {
h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
fmt.Fprintf(h, "%s", sig)
sid = hex.EncodeToString(h.Sum(nil))
}
return
} }