From 481448fa90dbceb9554d17b69d293ee0282c6674 Mon Sep 17 00:00:00 2001 From: astaxie Date: Sun, 5 Jan 2014 14:48:36 +0800 Subject: [PATCH] modify session module change a log --- session/README.md | 35 +++++--- session/sess_cookie.go | 143 +++++++++++++++++++++++++++++ session/sess_file.go | 3 +- session/sess_gob.go | 38 -------- session/sess_mem.go | 10 +-- session/sess_mem_test.go | 35 ++++++++ session/sess_mysql.go | 3 +- session/sess_redis.go | 3 +- session/sess_test.go | 81 +++++++++++++++++ session/sess_utils.go | 188 +++++++++++++++++++++++++++++++++++++++ session/session.go | 154 +++++++++++++++----------------- 11 files changed, 551 insertions(+), 142 deletions(-) create mode 100644 session/sess_cookie.go delete mode 100644 session/sess_gob.go create mode 100644 session/sess_mem_test.go create mode 100644 session/sess_utils.go diff --git a/session/README.md b/session/README.md index 220100ef..2ebf069a 100644 --- a/session/README.md +++ b/session/README.md @@ -28,21 +28,21 @@ Then in you web app init the global session manager * Use **memory** as provider: func init() { - globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"") + globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`) go globalSessions.GC() } * Use **file** as provider, the last param is the path where you want file to be stored: func init() { - globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp") + globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","./tmp"}`) go globalSessions.GC() } * Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: func init() { - globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie") + globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","127.0.0.1:6379,100,astaxie"}`) go globalSessions.GC() } @@ -50,15 +50,24 @@ Then in you web app init the global session manager func init() { globalSessions, _ = session.NewManager( - "mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value") + "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","username:password@protocol(address)/dbname?param=value"}`) go globalSessions.GC() } +* Use **Cookie** as provider: + + func init() { + globalSessions, _ = session.NewManager( + "cookie", `{"cookieName":"gosessionid","enableSetCookie":false,gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`) + go globalSessions.GC() + } + + Finally in the handlerfunc you can use it like this func login(w http.ResponseWriter, r *http.Request) { sess := globalSessions.SessionStart(w, r) - defer sess.SessionRelease() + defer sess.SessionRelease(w) username := sess.Get("username") fmt.Println(username) if r.Method == "GET" { @@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus Writing a provider is easy. You only need to define two struct types (Session and Provider), which satisfy the interface definition. -Maybe you will find the **memory** provider as good example. +Maybe you will find the **memory** provider is a good example. type SessionStore interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease() // release the resource & save data to provider - Flush() error //delete all data + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error //delete all data } type Provider interface { - SessionInit(maxlifetime int64, savePath string) error + SessionInit(gclifetime int64, config string) error SessionRead(sid string) (SessionStore, error) SessionExist(sid string) bool SessionRegenerate(oldsid, sid string) (SessionStore, error) diff --git a/session/sess_cookie.go b/session/sess_cookie.go new file mode 100644 index 00000000..deff70a0 --- /dev/null +++ b/session/sess_cookie.go @@ -0,0 +1,143 @@ +package session + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/json" + "net/http" + "net/url" + "sync" +) + +var cookiepder = &CookieProvider{} + +type CookieSessionStore struct { + sid string + values map[interface{}]interface{} //session data + lock sync.RWMutex +} + +func (st *CookieSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +func (st *CookieSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } else { + return nil + } + return nil +} + +func (st *CookieSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +func (st *CookieSessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +func (st *CookieSessionStore) SessionID() string { + return st.sid +} + +func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { + str, err := encodeCookie(cookiepder.block, + cookiepder.config.SecurityKey, + cookiepder.config.SecurityName, + st.values) + if err != nil { + return + } + cookie := &http.Cookie{Name: cookiepder.config.CookieName, + Value: url.QueryEscape(str), + Path: "/", + HttpOnly: true, + Secure: cookiepder.config.Secure} + http.SetCookie(w, cookie) + return +} + +type cookieConfig struct { + SecurityKey string `json:"securityKey"` + BlockKey string `json:"blockKey"` + SecurityName string `json:"securityName"` + CookieName string `json:"cookieName"` + Secure bool `json:"secure"` + Maxage int `json:"maxage"` +} + +type CookieProvider struct { + maxlifetime int64 + config *cookieConfig + block cipher.Block +} + +func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { + pder.config = &cookieConfig{} + err := json.Unmarshal([]byte(config), pder.config) + if err != nil { + return err + } + if pder.config.BlockKey == "" { + pder.config.BlockKey = string(generateRandomKey(16)) + } + if pder.config.SecurityName == "" { + pder.config.SecurityName = string(generateRandomKey(20)) + } + pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey)) + if err != nil { + return err + } + return nil +} + +func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) { + kv := make(map[interface{}]interface{}) + kv, _ = decodeCookie(pder.block, + pder.config.SecurityKey, + pder.config.SecurityName, + sid, pder.maxlifetime) + rs := &CookieSessionStore{sid: sid, values: kv} + return rs, nil +} + +func (pder *CookieProvider) SessionExist(sid string) bool { + return true +} + +func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { + return nil, nil +} + +func (pder *CookieProvider) SessionDestroy(sid string) error { + return nil +} + +func (pder *CookieProvider) SessionGC() { + return +} + +func (pder *CookieProvider) SessionAll() int { + return 0 +} + +func (pder *CookieProvider) SessionUpdate(sid string) error { + return nil +} + +func init() { + Register("cookie", cookiepder) +} diff --git a/session/sess_file.go b/session/sess_file.go index 1db4022e..5d33d0e2 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "io/ioutil" + "net/http" "os" "path" "path/filepath" @@ -60,7 +61,7 @@ func (fs *FileSessionStore) SessionID() string { return fs.sid } -func (fs *FileSessionStore) SessionRelease() { +func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { defer fs.f.Close() b, err := encodeGob(fs.values) if err != nil { diff --git a/session/sess_gob.go b/session/sess_gob.go deleted file mode 100644 index 92313947..00000000 --- a/session/sess_gob.go +++ /dev/null @@ -1,38 +0,0 @@ -package session - -import ( - "bytes" - "encoding/gob" -) - -func init() { - gob.Register([]interface{}{}) - gob.Register(map[int]interface{}{}) - gob.Register(map[string]interface{}{}) - gob.Register(map[interface{}]interface{}{}) - gob.Register(map[string]string{}) - gob.Register(map[int]string{}) - gob.Register(map[int]int{}) - gob.Register(map[int]int64{}) -} - -func encodeGob(obj map[interface{}]interface{}) ([]byte, error) { - buf := bytes.NewBuffer(nil) - enc := gob.NewEncoder(buf) - err := enc.Encode(obj) - if err != nil { - return []byte(""), err - } - return buf.Bytes(), nil -} - -func decodeGob(encoded []byte) (map[interface{}]interface{}, error) { - buf := bytes.NewBuffer(encoded) - dec := gob.NewDecoder(buf) - var out map[interface{}]interface{} - err := dec.Decode(&out) - if err != nil { - return nil, err - } - return out, nil -} diff --git a/session/sess_mem.go b/session/sess_mem.go index 2e615c6f..c74c2602 100644 --- a/session/sess_mem.go +++ b/session/sess_mem.go @@ -2,6 +2,7 @@ package session import ( "container/list" + "net/http" "sync" "time" ) @@ -9,9 +10,9 @@ import ( var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} type MemSessionStore struct { - sid string //session id唯一标示 - timeAccessed time.Time //最后访问时间 - value map[interface{}]interface{} //session里面存储的值 + sid string //session id + timeAccessed time.Time //last access time + value map[interface{}]interface{} //session store lock sync.RWMutex } @@ -51,8 +52,7 @@ func (st *MemSessionStore) SessionID() string { return st.sid } -func (st *MemSessionStore) SessionRelease() { - +func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { } type MemProvider struct { diff --git a/session/sess_mem_test.go b/session/sess_mem_test.go new file mode 100644 index 00000000..df2a9a1e --- /dev/null +++ b/session/sess_mem_test.go @@ -0,0 +1,35 @@ +package session + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestMem(t *testing.T) { + globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`) + go globalSessions.GC() + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess := globalSessions.SessionStart(w, r) + defer sess.SessionRelease(w) + err := sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} diff --git a/session/sess_mysql.go b/session/sess_mysql.go index f1a59d4f..1101e437 100644 --- a/session/sess_mysql.go +++ b/session/sess_mysql.go @@ -9,6 +9,7 @@ package session import ( "database/sql" + "net/http" "sync" "time" @@ -60,7 +61,7 @@ func (st *MysqlSessionStore) SessionID() string { return st.sid } -func (st *MysqlSessionStore) SessionRelease() { +func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) { defer st.c.Close() if len(st.values) > 0 { b, err := encodeGob(st.values) diff --git a/session/sess_redis.go b/session/sess_redis.go index e582c6ed..0f8c0308 100644 --- a/session/sess_redis.go +++ b/session/sess_redis.go @@ -1,6 +1,7 @@ package session import ( + "net/http" "strconv" "strings" "sync" @@ -58,7 +59,7 @@ func (rs *RedisSessionStore) SessionID() string { return rs.sid } -func (rs *RedisSessionStore) SessionRelease() { +func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) { defer rs.c.Close() if len(rs.values) > 0 { b, err := encodeGob(rs.values) diff --git a/session/sess_test.go b/session/sess_test.go index b7d0b38c..d754b526 100644 --- a/session/sess_test.go +++ b/session/sess_test.go @@ -1,6 +1,8 @@ package session import ( + "crypto/aes" + "encoding/json" "testing" ) @@ -26,3 +28,82 @@ func Test_gob(t *testing.T) { t.Error("decode int error") } } + +func TestGenerate(t *testing.T) { + str := generateRandomKey(20) + if len(str) != 20 { + t.Fatal("generate length is not equal to 20") + } +} + +func TestCookieEncodeDecode(t *testing.T) { + hashKey := "testhashKey" + blockkey := generateRandomKey(16) + block, err := aes.NewCipher(blockkey) + if err != nil { + t.Fatal("NewCipher:", err) + } + securityName := string(generateRandomKey(20)) + val := make(map[interface{}]interface{}) + val["name"] = "astaxie" + val["gender"] = "male" + str, err := encodeCookie(block, hashKey, securityName, val) + if err != nil { + t.Fatal("encodeCookie:", err) + } + dst := make(map[interface{}]interface{}) + dst, err = decodeCookie(block, hashKey, securityName, str, 3600) + if err != nil { + t.Fatal("decodeCookie", err) + } + if dst["name"] != "astaxie" { + t.Fatal("dst get map error") + } + if dst["gender"] != "male" { + t.Fatal("dst get map error") + } +} + +func TestParseConfig(t *testing.T) { + s := `{"cookieName":"gosessionid","gclifetime":3600}` + cf := new(managerConfig) + cf.EnableSetCookie = true + err := json.Unmarshal([]byte(s), cf) + if err != nil { + t.Fatal("parse json error,", err) + } + if cf.CookieName != "gosessionid" { + t.Fatal("parseconfig get cookiename error") + } + if cf.Gclifetime != 3600 { + t.Fatal("parseconfig get gclifetime error") + } + + cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + cf2 := new(managerConfig) + cf2.EnableSetCookie = true + err = json.Unmarshal([]byte(cc), cf2) + if err != nil { + t.Fatal("parse json error,", err) + } + if cf2.CookieName != "gosessionid" { + t.Fatal("parseconfig get cookiename error") + } + if cf2.Gclifetime != 3600 { + t.Fatal("parseconfig get gclifetime error") + } + if cf2.EnableSetCookie != false { + t.Fatal("parseconfig get enableSetCookie error") + } + cconfig := new(cookieConfig) + err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig) + if err != nil { + t.Fatal("parse ProviderConfig err,", err) + } + if cconfig.CookieName != "gosessionid" { + t.Fatal("ProviderConfig get cookieName error") + } + if cconfig.SecurityKey != "beegocookiehashkey" { + t.Fatal("ProviderConfig get securityKey error") + } +} diff --git a/session/sess_utils.go b/session/sess_utils.go new file mode 100644 index 00000000..73f96630 --- /dev/null +++ b/session/sess_utils.go @@ -0,0 +1,188 @@ +package session + +import ( + "bytes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha1" + "crypto/subtle" + "encoding/base64" + "encoding/gob" + "errors" + "fmt" + "io" + "strconv" + "time" +) + +func init() { + gob.Register([]interface{}{}) + gob.Register(map[int]interface{}{}) + gob.Register(map[string]interface{}{}) + gob.Register(map[interface{}]interface{}{}) + gob.Register(map[string]string{}) + gob.Register(map[int]string{}) + gob.Register(map[int]int{}) + gob.Register(map[int]int64{}) +} + +func encodeGob(obj map[interface{}]interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(obj) + if err != nil { + return []byte(""), err + } + return buf.Bytes(), nil +} + +func decodeGob(encoded []byte) (map[interface{}]interface{}, error) { + buf := bytes.NewBuffer(encoded) + dec := gob.NewDecoder(buf) + var out map[interface{}]interface{} + err := dec.Decode(&out) + if err != nil { + return nil, err + } + return out, nil +} + +// generateRandomKey creates a random key with the given strength. +func generateRandomKey(strength int) []byte { + k := make([]byte, strength) + if _, err := io.ReadFull(rand.Reader, k); err != nil { + return nil + } + return k +} + +// Encryption ----------------------------------------------------------------- + +// encrypt encrypts a value using the given block in counter mode. +// +// A random initialization vector (http://goo.gl/zF67k) with the length of the +// block size is prepended to the resulting ciphertext. +func encrypt(block cipher.Block, value []byte) ([]byte, error) { + iv := generateRandomKey(block.BlockSize()) + if iv == nil { + return nil, errors.New("encrypt: failed to generate random iv") + } + // Encrypt it. + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(value, value) + // Return iv + ciphertext. + return append(iv, value...), nil +} + +// decrypt decrypts a value using the given block in counter mode. +// +// The value to be decrypted must be prepended by a initialization vector +// (http://goo.gl/zF67k) with the length of the block size. +func decrypt(block cipher.Block, value []byte) ([]byte, error) { + size := block.BlockSize() + if len(value) > size { + // Extract iv. + iv := value[:size] + // Extract ciphertext. + value = value[size:] + // Decrypt it. + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(value, value) + return value, nil + } + return nil, errors.New("decrypt: the value could not be decrypted") +} + +func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) { + var err error + var b []byte + // 1. encodeGob. + if b, err = encodeGob(value); err != nil { + return "", err + } + // 2. Encrypt (optional). + if b, err = encrypt(block, b); err != nil { + return "", err + } + b = encode(b) + // 3. Create MAC for "name|date|value". Extra pipe to be used later. + b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b)) + h := hmac.New(sha1.New, []byte(hashKey)) + h.Write(b) + sig := h.Sum(nil) + // Append mac, remove name. + b = append(b, sig...)[len(name)+1:] + // 4. Encode to base64. + b = encode(b) + // Done. + return string(b), nil +} + +func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) { + // 1. Decode from base64. + b, err := decode([]byte(value)) + if err != nil { + return nil, err + } + // 2. Verify MAC. Value is "date|value|mac". + parts := bytes.SplitN(b, []byte("|"), 3) + if len(parts) != 3 { + return nil, errors.New("Decode: invalid value %v") + } + + b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...) + h := hmac.New(sha1.New, []byte(hashKey)) + h.Write(b) + sig := h.Sum(nil) + if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 { + return nil, errors.New("Decode: the value is not valid") + } + // 3. Verify date ranges. + var t1 int64 + if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { + return nil, errors.New("Decode: invalid timestamp") + } + t2 := time.Now().UTC().Unix() + if t1 > t2 { + return nil, errors.New("Decode: timestamp is too new") + } + if t1 < t2-gcmaxlifetime { + return nil, errors.New("Decode: expired timestamp") + } + // 4. Decrypt (optional). + b, err = decode(parts[1]) + if err != nil { + return nil, err + } + if b, err = decrypt(block, b); err != nil { + return nil, err + } + // 5. decodeGob. + if dst, err := decodeGob(b); err != nil { + return nil, err + } else { + return dst, nil + } + // Done. + return nil, nil +} + +// Encoding ------------------------------------------------------------------- + +// encode encodes a value using base64. +func encode(value []byte) []byte { + encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value))) + base64.URLEncoding.Encode(encoded, value) + return encoded +} + +// decode decodes a cookie using base64. +func decode(value []byte) ([]byte, error) { + decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value))) + b, err := base64.URLEncoding.Decode(decoded, value) + if err != nil { + return nil, err + } + return decoded[:b], nil +} diff --git a/session/session.go b/session/session.go index 062bbfd6..df348fab 100644 --- a/session/session.go +++ b/session/session.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "crypto/sha1" "encoding/hex" + "encoding/json" "fmt" "io" "net/http" @@ -14,16 +15,16 @@ import ( ) type SessionStore interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease() // release the resource & save data to provider - Flush() error //delete all data + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error //delete all data } type Provider interface { - SessionInit(maxlifetime int64, savePath string) error + SessionInit(gclifetime int64, config string) error SessionRead(sid string) (SessionStore, error) SessionExist(sid string) bool SessionRegenerate(oldsid, sid string) (SessionStore, error) @@ -47,15 +48,21 @@ func Register(name string, provide Provider) { provides[name] = provide } +type managerConfig struct { + CookieName string `json:"cookieName"` + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + Gclifetime int64 `json:"gclifetime"` + Maxage int `json:"maxage"` + Secure bool `json:"secure"` + SessionIDHashFunc string `json:"sessionIDHashFunc"` + SessionIDHashKey string `json:"sessionIDHashKey"` + CookieLifeTime int64 `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` +} + type Manager struct { - cookieName string //private cookiename - provider Provider - maxlifetime int64 - hashfunc string //support md5 & sha1 - hashkey string - maxage int //cookielifetime - secure bool - options []interface{} + provider Provider + config *managerConfig } //options @@ -63,74 +70,49 @@ type Manager struct { //2. hashfunc default sha1 //3. hashkey default beegosessionkey //4. maxage default is none -func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) { +func NewManager(provideName, config string) (*Manager, error) { provider, ok := provides[provideName] if !ok { return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) } - provider.SessionInit(maxlifetime, savePath) - secure := false - if len(options) > 0 { - secure = options[0].(bool) + cf := new(managerConfig) + cf.EnableSetCookie = true + err := json.Unmarshal([]byte(config), cf) + if err != nil { + return nil, err } - hashfunc := "sha1" - if len(options) > 1 { - hashfunc = options[1].(string) + provider.SessionInit(cf.Gclifetime, cf.ProviderConfig) + + if cf.SessionIDHashFunc == "" { + cf.SessionIDHashFunc = "sha1" } - hashkey := "beegosessionkey" - if len(options) > 2 { - hashkey = options[2].(string) - } - maxage := -1 - if len(options) > 3 { - switch options[3].(type) { - case int: - if options[3].(int) > 0 { - maxage = options[3].(int) - } else if options[3].(int) < 0 { - maxage = 0 - } - case int64: - if options[3].(int64) > 0 { - maxage = int(options[3].(int64)) - } else if options[3].(int64) < 0 { - maxage = 0 - } - case int32: - if options[3].(int32) > 0 { - maxage = int(options[3].(int32)) - } else if options[3].(int32) < 0 { - maxage = 0 - } - } + if cf.SessionIDHashKey == "" { + cf.SessionIDHashKey = string(generateRandomKey(16)) } + return &Manager{ - provider: provider, - cookieName: cookieName, - maxlifetime: maxlifetime, - hashfunc: hashfunc, - hashkey: hashkey, - maxage: maxage, - secure: secure, - options: options, + provider, + cf, }, nil } //get Session func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { - cookie, err := r.Cookie(manager.cookieName) + cookie, err := r.Cookie(manager.config.CookieName) if err != nil || cookie.Value == "" { sid := manager.sessionId(r) session, _ = manager.provider.SessionRead(sid) - cookie = &http.Cookie{Name: manager.cookieName, + cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", HttpOnly: true, - Secure: manager.secure} - if manager.maxage >= 0 { - cookie.MaxAge = manager.maxage + Secure: manager.config.Secure} + if manager.config.Maxage >= 0 { + cookie.MaxAge = manager.config.Maxage + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) } - http.SetCookie(w, cookie) r.AddCookie(cookie) } else { sid, _ := url.QueryUnescape(cookie.Value) @@ -139,15 +121,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se } else { sid = manager.sessionId(r) session, _ = manager.provider.SessionRead(sid) - cookie = &http.Cookie{Name: manager.cookieName, + cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", HttpOnly: true, - Secure: manager.secure} - if manager.maxage >= 0 { - cookie.MaxAge = manager.maxage + Secure: manager.config.Secure} + if manager.config.Maxage >= 0 { + cookie.MaxAge = manager.config.Maxage + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) } - http.SetCookie(w, cookie) r.AddCookie(cookie) } } @@ -156,13 +140,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se //Destroy sessionid func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { - cookie, err := r.Cookie(manager.cookieName) + cookie, err := r.Cookie(manager.config.CookieName) if err != nil || cookie.Value == "" { return } else { manager.provider.SessionDestroy(cookie.Value) expiration := time.Now() - cookie := http.Cookie{Name: manager.cookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1} + cookie := http.Cookie{Name: manager.config.CookieName, + Path: "/", + HttpOnly: true, + Expires: expiration, + MaxAge: -1} http.SetCookie(w, &cookie) } } @@ -174,20 +162,20 @@ func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err erro func (manager *Manager) GC() { manager.provider.SessionGC() - time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) + time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) } func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) { sid := manager.sessionId(r) - cookie, err := r.Cookie(manager.cookieName) + cookie, err := r.Cookie(manager.config.CookieName) if err != nil && cookie.Value == "" { //delete old cookie session, _ = manager.provider.SessionRead(sid) - cookie = &http.Cookie{Name: manager.cookieName, + cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", HttpOnly: true, - Secure: manager.secure, + Secure: manager.config.Secure, } } else { oldsid, _ := url.QueryUnescape(cookie.Value) @@ -196,8 +184,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque cookie.HttpOnly = true cookie.Path = "/" } - if manager.maxage >= 0 { - cookie.MaxAge = manager.maxage + if manager.config.Maxage >= 0 { + cookie.MaxAge = manager.config.Maxage } http.SetCookie(w, cookie) r.AddCookie(cookie) @@ -209,12 +197,12 @@ func (manager *Manager) GetActiveSession() int { } func (manager *Manager) SetHashFunc(hasfunc, hashkey string) { - manager.hashfunc = hasfunc - manager.hashkey = hashkey + manager.config.SessionIDHashFunc = hasfunc + manager.config.SessionIDHashKey = hashkey } func (manager *Manager) SetSecure(secure bool) { - manager.secure = secure + manager.config.Secure = secure } //remote_addr cruunixnano randdata @@ -224,16 +212,16 @@ func (manager *Manager) sessionId(r *http.Request) (sid string) { return "" } sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs) - if manager.hashfunc == "md5" { + if manager.config.SessionIDHashFunc == "md5" { h := md5.New() h.Write([]byte(sig)) sid = hex.EncodeToString(h.Sum(nil)) - } else if manager.hashfunc == "sha1" { - h := hmac.New(sha1.New, []byte(manager.hashkey)) + } else if manager.config.SessionIDHashFunc == "sha1" { + h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey)) fmt.Fprintf(h, "%s", sig) sid = hex.EncodeToString(h.Sum(nil)) } else { - h := hmac.New(sha1.New, []byte(manager.hashkey)) + h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey)) fmt.Fprintf(h, "%s", sig) sid = hex.EncodeToString(h.Sum(nil)) }