From 02c2e162535ac37825b5b8161e1544d7d3306659 Mon Sep 17 00:00:00 2001 From: astaxie Date: Thu, 26 Sep 2013 18:07:00 +0800 Subject: [PATCH] Strengthens the session's function --- session/sess_file.go | 59 +++++++++ session/sess_mem.go | 286 +++++++++++++++++++++++------------------- session/sess_mysql.go | 30 +++++ session/sess_redis.go | 15 +++ session/session.go | 142 +++++++++++++++++++-- 5 files changed, 397 insertions(+), 135 deletions(-) diff --git a/session/sess_file.go b/session/sess_file.go index c3306b78..955cc9e0 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -1,6 +1,8 @@ package session import ( + "errors" + "io" "io/ioutil" "os" "path" @@ -48,6 +50,14 @@ func (fs *FileSessionStore) Delete(key interface{}) error { return nil } +func (fs *FileSessionStore) Flush() error { + fs.lock.Lock() + defer fs.lock.Unlock() + fs.values = make(map[interface{}]interface{}) + fs.updatecontent() + return nil +} + func (fs *FileSessionStore) SessionID() string { return fs.sid } @@ -121,6 +131,55 @@ func (fp *FileProvider) SessionGC() { filepath.Walk(fp.savePath, gcpath) } +func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { + err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777) + if err != nil { + println(err.Error()) + } + err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) + if err != nil { + println(err.Error()) + } + _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + var newf *os.File + if err == nil { + return nil, errors.New("newsid exist") + } else if os.IsNotExist(err) { + newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + } + + _, err = os.Stat(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid)) + var f *os.File + if err == nil { + f, err = os.OpenFile(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid), os.O_RDWR, 0777) + io.Copy(newf, f) + } else if os.IsNotExist(err) { + newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + } else { + return nil, err + } + f.Close() + os.Remove(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]))) + os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) + var kv map[interface{}]interface{} + b, err := ioutil.ReadAll(newf) + if err != nil { + return nil, err + } + if len(b) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = decodeGob(b) + if err != nil { + return nil, err + } + } + + newf, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_WRONLY|os.O_CREATE, 0777) + ss := &FileSessionStore{f: newf, sid: sid, values: kv} + return ss, nil +} + func gcpath(path string, info os.FileInfo, err error) error { if err != nil { return err diff --git a/session/sess_mem.go b/session/sess_mem.go index 8ba391f8..fd021dac 100644 --- a/session/sess_mem.go +++ b/session/sess_mem.go @@ -1,128 +1,158 @@ -package session - -import ( - "container/list" - "sync" - "time" -) - -var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} - -type MemSessionStore struct { - sid string //session id唯一标示 - timeAccessed time.Time //最后访问时间 - value map[interface{}]interface{} //session里面存储的值 - lock sync.RWMutex -} - -func (st *MemSessionStore) Set(key, value interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - st.value[key] = value - return nil -} - -func (st *MemSessionStore) Get(key interface{}) interface{} { - st.lock.RLock() - defer st.lock.RUnlock() - if v, ok := st.value[key]; ok { - return v - } else { - return nil - } - return nil -} - -func (st *MemSessionStore) Delete(key interface{}) error { - st.lock.Lock() - defer st.lock.Unlock() - delete(st.value, key) - return nil -} - -func (st *MemSessionStore) SessionID() string { - return st.sid -} - -func (st *MemSessionStore) SessionRelease() { - -} - -type MemProvider struct { - lock sync.RWMutex //用来锁 - sessions map[string]*list.Element //用来存储在内存 - list *list.List //用来做gc - maxlifetime int64 - savePath string -} - -func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { - pder.maxlifetime = maxlifetime - pder.savePath = savePath - return nil -} - -func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) { - pder.lock.RLock() - if element, ok := pder.sessions[sid]; ok { - go pder.SessionUpdate(sid) - pder.lock.RUnlock() - return element.Value.(*MemSessionStore), nil - } else { - pder.lock.RUnlock() - pder.lock.Lock() - newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} - element := pder.list.PushBack(newsess) - pder.sessions[sid] = element - pder.lock.Unlock() - return newsess, nil - } - return nil, nil -} - -func (pder *MemProvider) SessionDestroy(sid string) error { - pder.lock.Lock() - defer pder.lock.Unlock() - if element, ok := pder.sessions[sid]; ok { - delete(pder.sessions, sid) - pder.list.Remove(element) - return nil - } - return nil -} - -func (pder *MemProvider) SessionGC() { - pder.lock.RLock() - for { - element := pder.list.Back() - if element == nil { - break - } - if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() { - pder.lock.RUnlock() - pder.lock.Lock() - pder.list.Remove(element) - delete(pder.sessions, element.Value.(*MemSessionStore).sid) - pder.lock.Unlock() - pder.lock.RLock() - } else { - break - } - } - pder.lock.RUnlock() -} - -func (pder *MemProvider) SessionUpdate(sid string) error { - pder.lock.Lock() - defer pder.lock.Unlock() - if element, ok := pder.sessions[sid]; ok { - element.Value.(*MemSessionStore).timeAccessed = time.Now() - pder.list.MoveToFront(element) - return nil - } - return nil -} - -func init() { - Register("memory", mempder) -} +package session + +import ( + "container/list" + "sync" + "time" +) + +var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} + +type MemSessionStore struct { + sid string //session id唯一标示 + timeAccessed time.Time //最后访问时间 + value map[interface{}]interface{} //session里面存储的值 + lock sync.RWMutex +} + +func (st *MemSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.value[key] = value + return nil +} + +func (st *MemSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.value[key]; ok { + return v + } else { + return nil + } + return nil +} + +func (st *MemSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.value, key) + return nil +} + +func (st *MemSessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.value = make(map[interface{}]interface{}) + return nil +} + +func (st *MemSessionStore) SessionID() string { + return st.sid +} + +func (st *MemSessionStore) SessionRelease() { + +} + +type MemProvider struct { + lock sync.RWMutex //用来锁 + sessions map[string]*list.Element //用来存储在内存 + list *list.List //用来做gc + maxlifetime int64 + savePath string +} + +func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + pder.maxlifetime = maxlifetime + pder.savePath = savePath + return nil +} + +func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) { + pder.lock.RLock() + if element, ok := pder.sessions[sid]; ok { + go pder.SessionUpdate(sid) + pder.lock.RUnlock() + return element.Value.(*MemSessionStore), nil + } else { + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushBack(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil + } + return nil, nil +} + +func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { + pder.lock.RLock() + if element, ok := pder.sessions[oldsid]; ok { + go pder.SessionUpdate(oldsid) + pder.lock.RUnlock() + pder.lock.Lock() + element.Value.(*MemSessionStore).sid = sid + pder.sessions[sid] = element + delete(pder.sessions, oldsid) + pder.lock.Unlock() + return element.Value.(*MemSessionStore), nil + } else { + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushBack(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil + } + return nil, nil +} + +func (pder *MemProvider) SessionDestroy(sid string) error { + pder.lock.Lock() + defer pder.lock.Unlock() + if element, ok := pder.sessions[sid]; ok { + delete(pder.sessions, sid) + pder.list.Remove(element) + return nil + } + return nil +} + +func (pder *MemProvider) SessionGC() { + pder.lock.RLock() + for { + element := pder.list.Back() + if element == nil { + break + } + if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() { + pder.lock.RUnlock() + pder.lock.Lock() + pder.list.Remove(element) + delete(pder.sessions, element.Value.(*MemSessionStore).sid) + pder.lock.Unlock() + pder.lock.RLock() + } else { + break + } + } + pder.lock.RUnlock() +} + +func (pder *MemProvider) SessionUpdate(sid string) error { + pder.lock.Lock() + defer pder.lock.Unlock() + if element, ok := pder.sessions[sid]; ok { + element.Value.(*MemSessionStore).timeAccessed = time.Now() + pder.list.MoveToFront(element) + return nil + } + return nil +} + +func init() { + Register("memory", mempder) +} diff --git a/session/sess_mysql.go b/session/sess_mysql.go index 217227e9..f1af8564 100644 --- a/session/sess_mysql.go +++ b/session/sess_mysql.go @@ -50,6 +50,14 @@ func (st *MysqlSessionStore) Delete(key interface{}) error { return nil } +func (st *MysqlSessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + st.updatemysql() + return nil +} + func (st *MysqlSessionStore) SessionID() string { return st.sid } @@ -108,6 +116,28 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) { return rs, nil } +func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from session where session_key=?", oldsid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) + } + c.Exec("update session set `session_key`=? where session_key=?", sid, oldsid) + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = decodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &MysqlSessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + func (mp *MysqlProvider) SessionDestroy(sid string) error { c := mp.connectInit() c.Exec("DELETE FROM session where session_key=?", sid) diff --git a/session/sess_redis.go b/session/sess_redis.go index a5d41071..13348e84 100644 --- a/session/sess_redis.go +++ b/session/sess_redis.go @@ -35,6 +35,11 @@ func (rs *RedisSessionStore) Delete(key interface{}) error { return err } +func (rs *RedisSessionStore) Flush() error { + _, err := rs.c.Do("DEL", rs.sid) + return err +} + func (rs *RedisSessionStore) SessionID() string { return rs.sid } @@ -99,6 +104,16 @@ func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) { return rs, nil } +func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { + c := rp.connectInit() + if str, err := redis.String(c.Do("HGET", oldsid, oldsid)); err != nil || str == "" { + c.Do("HSET", oldsid, oldsid, rp.maxlifetime) + } + c.Do("RENAME", oldsid, sid) + rs := &RedisSessionStore{c: c, sid: sid} + return rs, nil +} + func (rp *RedisProvider) SessionDestroy(sid string) error { c := rp.connectInit() c.Do("DEL", sid) diff --git a/session/session.go b/session/session.go index 3491ee33..5b85ffc3 100644 --- a/session/session.go +++ b/session/session.go @@ -1,8 +1,12 @@ package session import ( + "crypto/hmac" + "crypto/md5" "crypto/rand" + "crypto/sha1" "encoding/base64" + "encoding/hex" "fmt" "io" "net/http" @@ -16,11 +20,13 @@ type SessionStore interface { Delete(key interface{}) error //delete session value SessionID() string //back current sessionID SessionRelease() // release the resource + Flush() error //delete all data } type Provider interface { SessionInit(maxlifetime int64, savePath string) error SessionRead(sid string) (SessionStore, error) + SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionDestroy(sid string) error SessionGC() } @@ -44,40 +50,91 @@ type Manager struct { cookieName string //private cookiename provider Provider maxlifetime int64 + hashfunc string //support md5 & sha1 + hashkey string options []interface{} } +//options +//1. is https default false +//2. hashfunc default sha1 +//3. hashkey default beegosessionkey +//4. maxage default is none func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) { provider, ok := provides[provideName] if !ok { return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) } provider.SessionInit(maxlifetime, savePath) - return &Manager{provider: provider, cookieName: cookieName, maxlifetime: maxlifetime, options: options}, nil + hashfunc := "sha1" + if len(options) > 1 { + hashfunc = options[1].(string) + } + hashkey := "beegosessionkey" + if len(options) > 2 { + hashkey = options[2].(string) + } + return &Manager{ + provider: provider, + cookieName: cookieName, + maxlifetime: maxlifetime, + hashfunc: hashfunc, + hashkey: hashkey, + options: options, + }, nil } //get Session func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { cookie, err := r.Cookie(manager.cookieName) + maxage := -1 + if len(manager.options) > 3 { + switch manager.options[3].(type) { + case int: + if manager.options[3].(int) > 0 { + maxage = manager.options[3].(int) + } else if manager.options[3].(int) < 0 { + maxage = 0 + } + case int64: + if manager.options[3].(int64) > 0 { + maxage = int(manager.options[3].(int64)) + } else if manager.options[3].(int64) < 0 { + maxage = 0 + } + case int32: + if manager.options[3].(int32) > 0 { + maxage = int(manager.options[3].(int32)) + } else if manager.options[3].(int32) < 0 { + maxage = 0 + } + } + } if err != nil || cookie.Value == "" { - sid := manager.sessionId() + sid := manager.sessionId(r) session, _ = manager.provider.SessionRead(sid) secure := false if len(manager.options) > 0 { secure = manager.options[0].(bool) } - cookie := http.Cookie{Name: manager.cookieName, + cookie = &http.Cookie{Name: manager.cookieName, Value: url.QueryEscape(sid), Path: "/", HttpOnly: true, Secure: secure} + if maxage >= 0 { + cookie.MaxAge = maxage + } //cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second) - http.SetCookie(w, &cookie) - r.AddCookie(&cookie) + http.SetCookie(w, cookie) + r.AddCookie(cookie) } else { //cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second) cookie.HttpOnly = true cookie.Path = "/" + if maxage >= 0 { + cookie.MaxAge = maxage + } http.SetCookie(w, cookie) sid, _ := url.QueryUnescape(cookie.Value) session, _ = manager.provider.SessionRead(sid) @@ -103,10 +160,81 @@ func (manager *Manager) GC() { time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) } -func (manager *Manager) sessionId() string { +func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) { + sid := manager.sessionId(r) + cookie, err := r.Cookie(manager.cookieName) + if err != nil && cookie.Value == "" { + //delete old cookie + session, _ = manager.provider.SessionRead(sid) + secure := false + if len(manager.options) > 0 { + secure = manager.options[0].(bool) + } + cookie = &http.Cookie{Name: manager.cookieName, + Value: url.QueryEscape(sid), + Path: "/", + HttpOnly: true, + Secure: secure, + } + } else { + oldsid, _ := url.QueryUnescape(cookie.Value) + session, _ = manager.provider.SessionRegenerate(oldsid, sid) + cookie.Value = url.QueryEscape(sid) + cookie.HttpOnly = true + cookie.Path = "/" + } + maxage := -1 + if len(manager.options) > 3 { + switch manager.options[3].(type) { + case int: + if manager.options[3].(int) > 0 { + maxage = manager.options[3].(int) + } else if manager.options[3].(int) < 0 { + maxage = 0 + } + case int64: + if manager.options[3].(int64) > 0 { + maxage = int(manager.options[3].(int64)) + } else if manager.options[3].(int64) < 0 { + maxage = 0 + } + case int32: + if manager.options[3].(int32) > 0 { + maxage = int(manager.options[3].(int32)) + } else if manager.options[3].(int32) < 0 { + maxage = 0 + } + } + } + if maxage >= 0 { + cookie.MaxAge = maxage + } + http.SetCookie(w, cookie) + r.AddCookie(cookie) + return +} + +//remote_addr cruunixnano randdata + +func (manager *Manager) sessionId(r *http.Request) (sid string) { b := make([]byte, 24) if _, err := io.ReadFull(rand.Reader, b); err != nil { return "" } - return base64.URLEncoding.EncodeToString(b) + bs := base64.URLEncoding.EncodeToString(b) + sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs) + if manager.hashfunc == "md5" { + h := md5.New() + h.Write([]byte(bs)) + sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil))) + } else if manager.hashfunc == "sha1" { + h := hmac.New(sha1.New, []byte(manager.hashkey)) + fmt.Fprintf(h, "%s", sig) + sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil))) + } else { + h := hmac.New(sha1.New, []byte(manager.hashkey)) + fmt.Fprintf(h, "%s", sig) + sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil))) + } + return }