1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-22 12:10:55 +00:00

Strengthens the session's function

This commit is contained in:
astaxie 2013-09-26 18:07:00 +08:00
parent 59a67720b4
commit 02c2e16253
5 changed files with 397 additions and 135 deletions

View File

@ -1,6 +1,8 @@
package session package session
import ( import (
"errors"
"io"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
@ -48,6 +50,14 @@ func (fs *FileSessionStore) Delete(key interface{}) error {
return nil return nil
} }
func (fs *FileSessionStore) Flush() error {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.values = make(map[interface{}]interface{})
fs.updatecontent()
return nil
}
func (fs *FileSessionStore) SessionID() string { func (fs *FileSessionStore) SessionID() string {
return fs.sid return fs.sid
} }
@ -121,6 +131,55 @@ func (fp *FileProvider) SessionGC() {
filepath.Walk(fp.savePath, gcpath) filepath.Walk(fp.savePath, gcpath)
} }
func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777)
if err != nil {
println(err.Error())
}
err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
if err != nil {
println(err.Error())
}
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var newf *os.File
if err == nil {
return nil, errors.New("newsid exist")
} else if os.IsNotExist(err) {
newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
}
_, err = os.Stat(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid))
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid), os.O_RDWR, 0777)
io.Copy(newf, f)
} else if os.IsNotExist(err) {
newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
} else {
return nil, err
}
f.Close()
os.Remove(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])))
os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
var kv map[interface{}]interface{}
b, err := ioutil.ReadAll(newf)
if err != nil {
return nil, err
}
if len(b) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = decodeGob(b)
if err != nil {
return nil, err
}
}
newf, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_WRONLY|os.O_CREATE, 0777)
ss := &FileSessionStore{f: newf, sid: sid, values: kv}
return ss, nil
}
func gcpath(path string, info os.FileInfo, err error) error { func gcpath(path string, info os.FileInfo, err error) error {
if err != nil { if err != nil {
return err return err

View File

@ -1,128 +1,158 @@
package session package session
import ( import (
"container/list" "container/list"
"sync" "sync"
"time" "time"
) )
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
type MemSessionStore struct { type MemSessionStore struct {
sid string //session id唯一标示 sid string //session id唯一标示
timeAccessed time.Time //最后访问时间 timeAccessed time.Time //最后访问时间
value map[interface{}]interface{} //session里面存储的值 value map[interface{}]interface{} //session里面存储的值
lock sync.RWMutex lock sync.RWMutex
} }
func (st *MemSessionStore) Set(key, value interface{}) error { func (st *MemSessionStore) Set(key, value interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
st.value[key] = value st.value[key] = value
return nil return nil
} }
func (st *MemSessionStore) Get(key interface{}) interface{} { func (st *MemSessionStore) Get(key interface{}) interface{} {
st.lock.RLock() st.lock.RLock()
defer st.lock.RUnlock() defer st.lock.RUnlock()
if v, ok := st.value[key]; ok { if v, ok := st.value[key]; ok {
return v return v
} else { } else {
return nil return nil
} }
return nil return nil
} }
func (st *MemSessionStore) Delete(key interface{}) error { func (st *MemSessionStore) Delete(key interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
delete(st.value, key) delete(st.value, key)
return nil return nil
} }
func (st *MemSessionStore) SessionID() string { func (st *MemSessionStore) Flush() error {
return st.sid st.lock.Lock()
} defer st.lock.Unlock()
st.value = make(map[interface{}]interface{})
func (st *MemSessionStore) SessionRelease() { return nil
}
}
func (st *MemSessionStore) SessionID() string {
type MemProvider struct { return st.sid
lock sync.RWMutex //用来锁 }
sessions map[string]*list.Element //用来存储在内存
list *list.List //用来做gc func (st *MemSessionStore) SessionRelease() {
maxlifetime int64
savePath string }
}
type MemProvider struct {
func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { lock sync.RWMutex //用来锁
pder.maxlifetime = maxlifetime sessions map[string]*list.Element //用来存储在内存
pder.savePath = savePath list *list.List //用来做gc
return nil maxlifetime int64
} savePath string
}
func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
pder.lock.RLock() func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
if element, ok := pder.sessions[sid]; ok { pder.maxlifetime = maxlifetime
go pder.SessionUpdate(sid) pder.savePath = savePath
pder.lock.RUnlock() return nil
return element.Value.(*MemSessionStore), nil }
} else {
pder.lock.RUnlock() func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
pder.lock.Lock() pder.lock.RLock()
newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} if element, ok := pder.sessions[sid]; ok {
element := pder.list.PushBack(newsess) go pder.SessionUpdate(sid)
pder.sessions[sid] = element pder.lock.RUnlock()
pder.lock.Unlock() return element.Value.(*MemSessionStore), nil
return newsess, nil } else {
} pder.lock.RUnlock()
return nil, nil pder.lock.Lock()
} newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
element := pder.list.PushBack(newsess)
func (pder *MemProvider) SessionDestroy(sid string) error { pder.sessions[sid] = element
pder.lock.Lock() pder.lock.Unlock()
defer pder.lock.Unlock() return newsess, nil
if element, ok := pder.sessions[sid]; ok { }
delete(pder.sessions, sid) return nil, nil
pder.list.Remove(element) }
return nil
} func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
return nil pder.lock.RLock()
} if element, ok := pder.sessions[oldsid]; ok {
go pder.SessionUpdate(oldsid)
func (pder *MemProvider) SessionGC() { pder.lock.RUnlock()
pder.lock.RLock() pder.lock.Lock()
for { element.Value.(*MemSessionStore).sid = sid
element := pder.list.Back() pder.sessions[sid] = element
if element == nil { delete(pder.sessions, oldsid)
break pder.lock.Unlock()
} return element.Value.(*MemSessionStore), nil
if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() { } else {
pder.lock.RUnlock() pder.lock.RUnlock()
pder.lock.Lock() pder.lock.Lock()
pder.list.Remove(element) newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
delete(pder.sessions, element.Value.(*MemSessionStore).sid) element := pder.list.PushBack(newsess)
pder.lock.Unlock() pder.sessions[sid] = element
pder.lock.RLock() pder.lock.Unlock()
} else { return newsess, nil
break }
} return nil, nil
} }
pder.lock.RUnlock()
} func (pder *MemProvider) SessionDestroy(sid string) error {
pder.lock.Lock()
func (pder *MemProvider) SessionUpdate(sid string) error { defer pder.lock.Unlock()
pder.lock.Lock() if element, ok := pder.sessions[sid]; ok {
defer pder.lock.Unlock() delete(pder.sessions, sid)
if element, ok := pder.sessions[sid]; ok { pder.list.Remove(element)
element.Value.(*MemSessionStore).timeAccessed = time.Now() return nil
pder.list.MoveToFront(element) }
return nil return nil
} }
return nil
} func (pder *MemProvider) SessionGC() {
pder.lock.RLock()
func init() { for {
Register("memory", mempder) element := pder.list.Back()
} if element == nil {
break
}
if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() {
pder.lock.RUnlock()
pder.lock.Lock()
pder.list.Remove(element)
delete(pder.sessions, element.Value.(*MemSessionStore).sid)
pder.lock.Unlock()
pder.lock.RLock()
} else {
break
}
}
pder.lock.RUnlock()
}
func (pder *MemProvider) SessionUpdate(sid string) error {
pder.lock.Lock()
defer pder.lock.Unlock()
if element, ok := pder.sessions[sid]; ok {
element.Value.(*MemSessionStore).timeAccessed = time.Now()
pder.list.MoveToFront(element)
return nil
}
return nil
}
func init() {
Register("memory", mempder)
}

View File

@ -50,6 +50,14 @@ func (st *MysqlSessionStore) Delete(key interface{}) error {
return nil return nil
} }
func (st *MysqlSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
st.updatemysql()
return nil
}
func (st *MysqlSessionStore) SessionID() string { func (st *MysqlSessionStore) SessionID() string {
return st.sid return st.sid
} }
@ -108,6 +116,28 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
return rs, nil return rs, nil
} }
func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=?", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
}
c.Exec("update session set `session_key`=? where session_key=?", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = decodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &MysqlSessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
func (mp *MysqlProvider) SessionDestroy(sid string) error { func (mp *MysqlProvider) SessionDestroy(sid string) error {
c := mp.connectInit() c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=?", sid) c.Exec("DELETE FROM session where session_key=?", sid)

View File

@ -35,6 +35,11 @@ func (rs *RedisSessionStore) Delete(key interface{}) error {
return err return err
} }
func (rs *RedisSessionStore) Flush() error {
_, err := rs.c.Do("DEL", rs.sid)
return err
}
func (rs *RedisSessionStore) SessionID() string { func (rs *RedisSessionStore) SessionID() string {
return rs.sid return rs.sid
} }
@ -99,6 +104,16 @@ func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) {
return rs, nil return rs, nil
} }
func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
c := rp.connectInit()
if str, err := redis.String(c.Do("HGET", oldsid, oldsid)); err != nil || str == "" {
c.Do("HSET", oldsid, oldsid, rp.maxlifetime)
}
c.Do("RENAME", oldsid, sid)
rs := &RedisSessionStore{c: c, sid: sid}
return rs, nil
}
func (rp *RedisProvider) SessionDestroy(sid string) error { func (rp *RedisProvider) SessionDestroy(sid string) error {
c := rp.connectInit() c := rp.connectInit()
c.Do("DEL", sid) c.Do("DEL", sid)

View File

@ -1,8 +1,12 @@
package session package session
import ( import (
"crypto/hmac"
"crypto/md5"
"crypto/rand" "crypto/rand"
"crypto/sha1"
"encoding/base64" "encoding/base64"
"encoding/hex"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -16,11 +20,13 @@ type SessionStore interface {
Delete(key interface{}) error //delete session value Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID SessionID() string //back current sessionID
SessionRelease() // release the resource SessionRelease() // release the resource
Flush() error //delete all data
} }
type Provider interface { type Provider interface {
SessionInit(maxlifetime int64, savePath string) error SessionInit(maxlifetime int64, savePath string) error
SessionRead(sid string) (SessionStore, error) SessionRead(sid string) (SessionStore, error)
SessionRegenerate(oldsid, sid string) (SessionStore, error)
SessionDestroy(sid string) error SessionDestroy(sid string) error
SessionGC() SessionGC()
} }
@ -44,40 +50,91 @@ type Manager struct {
cookieName string //private cookiename cookieName string //private cookiename
provider Provider provider Provider
maxlifetime int64 maxlifetime int64
hashfunc string //support md5 & sha1
hashkey string
options []interface{} options []interface{}
} }
//options
//1. is https default false
//2. hashfunc default sha1
//3. hashkey default beegosessionkey
//4. maxage default is none
func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) { func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) {
provider, ok := provides[provideName] provider, ok := provides[provideName]
if !ok { if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
} }
provider.SessionInit(maxlifetime, savePath) provider.SessionInit(maxlifetime, savePath)
return &Manager{provider: provider, cookieName: cookieName, maxlifetime: maxlifetime, options: options}, nil hashfunc := "sha1"
if len(options) > 1 {
hashfunc = options[1].(string)
}
hashkey := "beegosessionkey"
if len(options) > 2 {
hashkey = options[2].(string)
}
return &Manager{
provider: provider,
cookieName: cookieName,
maxlifetime: maxlifetime,
hashfunc: hashfunc,
hashkey: hashkey,
options: options,
}, nil
} }
//get Session //get Session
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.cookieName)
maxage := -1
if len(manager.options) > 3 {
switch manager.options[3].(type) {
case int:
if manager.options[3].(int) > 0 {
maxage = manager.options[3].(int)
} else if manager.options[3].(int) < 0 {
maxage = 0
}
case int64:
if manager.options[3].(int64) > 0 {
maxage = int(manager.options[3].(int64))
} else if manager.options[3].(int64) < 0 {
maxage = 0
}
case int32:
if manager.options[3].(int32) > 0 {
maxage = int(manager.options[3].(int32))
} else if manager.options[3].(int32) < 0 {
maxage = 0
}
}
}
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
sid := manager.sessionId() sid := manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
secure := false secure := false
if len(manager.options) > 0 { if len(manager.options) > 0 {
secure = manager.options[0].(bool) secure = manager.options[0].(bool)
} }
cookie := http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.cookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: secure} Secure: secure}
if maxage >= 0 {
cookie.MaxAge = maxage
}
//cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second) //cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second)
http.SetCookie(w, &cookie) http.SetCookie(w, cookie)
r.AddCookie(&cookie) r.AddCookie(cookie)
} else { } else {
//cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second) //cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second)
cookie.HttpOnly = true cookie.HttpOnly = true
cookie.Path = "/" cookie.Path = "/"
if maxage >= 0 {
cookie.MaxAge = maxage
}
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
sid, _ := url.QueryUnescape(cookie.Value) sid, _ := url.QueryUnescape(cookie.Value)
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
@ -103,10 +160,81 @@ func (manager *Manager) GC() {
time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() })
} }
func (manager *Manager) sessionId() string { func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
sid := manager.sessionId(r)
cookie, err := r.Cookie(manager.cookieName)
if err != nil && cookie.Value == "" {
//delete old cookie
session, _ = manager.provider.SessionRead(sid)
secure := false
if len(manager.options) > 0 {
secure = manager.options[0].(bool)
}
cookie = &http.Cookie{Name: manager.cookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: true,
Secure: secure,
}
} else {
oldsid, _ := url.QueryUnescape(cookie.Value)
session, _ = manager.provider.SessionRegenerate(oldsid, sid)
cookie.Value = url.QueryEscape(sid)
cookie.HttpOnly = true
cookie.Path = "/"
}
maxage := -1
if len(manager.options) > 3 {
switch manager.options[3].(type) {
case int:
if manager.options[3].(int) > 0 {
maxage = manager.options[3].(int)
} else if manager.options[3].(int) < 0 {
maxage = 0
}
case int64:
if manager.options[3].(int64) > 0 {
maxage = int(manager.options[3].(int64))
} else if manager.options[3].(int64) < 0 {
maxage = 0
}
case int32:
if manager.options[3].(int32) > 0 {
maxage = int(manager.options[3].(int32))
} else if manager.options[3].(int32) < 0 {
maxage = 0
}
}
}
if maxage >= 0 {
cookie.MaxAge = maxage
}
http.SetCookie(w, cookie)
r.AddCookie(cookie)
return
}
//remote_addr cruunixnano randdata
func (manager *Manager) sessionId(r *http.Request) (sid string) {
b := make([]byte, 24) b := make([]byte, 24)
if _, err := io.ReadFull(rand.Reader, b); err != nil { if _, err := io.ReadFull(rand.Reader, b); err != nil {
return "" return ""
} }
return base64.URLEncoding.EncodeToString(b) bs := base64.URLEncoding.EncodeToString(b)
sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
if manager.hashfunc == "md5" {
h := md5.New()
h.Write([]byte(bs))
sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil)))
} else if manager.hashfunc == "sha1" {
h := hmac.New(sha1.New, []byte(manager.hashkey))
fmt.Fprintf(h, "%s", sig)
sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil)))
} else {
h := hmac.New(sha1.New, []byte(manager.hashkey))
fmt.Fprintf(h, "%s", sig)
sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil)))
}
return
} }