diff --git a/session/sess_file.go b/session/sess_file.go index 016a30f5..1db4022e 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -116,6 +116,15 @@ func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) { return ss, nil } +func (fp *FileProvider) SessionExist(sid string) bool { + _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + if err == nil { + return true + } else { + return false + } +} + func (fp *FileProvider) SessionDestroy(sid string) error { os.Remove(path.Join(fp.savePath)) return nil diff --git a/session/sess_mem.go b/session/sess_mem.go index 93cc7fc6..2e615c6f 100644 --- a/session/sess_mem.go +++ b/session/sess_mem.go @@ -87,6 +87,16 @@ func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) { return nil, nil } +func (pder *MemProvider) SessionExist(sid string) bool { + pder.lock.RLock() + defer pder.lock.RUnlock() + if _, ok := pder.sessions[sid]; ok { + return true + } else { + return false + } +} + func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { pder.lock.RLock() if element, ok := pder.sessions[oldsid]; ok { diff --git a/session/sess_mysql.go b/session/sess_mysql.go index 6d938aee..10dcf20b 100644 --- a/session/sess_mysql.go +++ b/session/sess_mysql.go @@ -110,6 +110,18 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) { return rs, nil } +func (mp *MysqlProvider) SessionExist(sid string) bool { + c := mp.connectInit() + row := c.QueryRow("select session_data from session where session_key=?", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + return false + } else { + return true + } +} + func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=?", oldsid) diff --git a/session/sess_redis.go b/session/sess_redis.go index 665d62eb..28fc6dca 100644 --- a/session/sess_redis.go +++ b/session/sess_redis.go @@ -145,6 +145,15 @@ func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) { return rs, nil } +func (rp *RedisProvider) SessionExist(sid string) bool { + c := rp.poollist.Get() + if str, err := redis.String(c.Do("HGET", sid, sid)); err != nil || str == "" { + return false + } else { + return true + } +} + func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { c := rp.poollist.Get() if str, err := redis.String(c.Do("HGET", oldsid, oldsid)); err != nil || str == "" { diff --git a/session/session.go b/session/session.go index 5db9bd8b..062bbfd6 100644 --- a/session/session.go +++ b/session/session.go @@ -25,6 +25,7 @@ type SessionStore interface { type Provider interface { SessionInit(maxlifetime int64, savePath string) error SessionRead(sid string) (SessionStore, error) + SessionExist(sid string) bool SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionDestroy(sid string) error SessionAll() int //get all active session @@ -133,7 +134,22 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se r.AddCookie(cookie) } else { sid, _ := url.QueryUnescape(cookie.Value) - session, _ = manager.provider.SessionRead(sid) + if manager.provider.SessionExist(sid) { + session, _ = manager.provider.SessionRead(sid) + } else { + sid = manager.sessionId(r) + session, _ = manager.provider.SessionRead(sid) + cookie = &http.Cookie{Name: manager.cookieName, + Value: url.QueryEscape(sid), + Path: "/", + HttpOnly: true, + Secure: manager.secure} + if manager.maxage >= 0 { + cookie.MaxAge = manager.maxage + } + http.SetCookie(w, cookie) + r.AddCookie(cookie) + } } return }