1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-25 19:00:55 +00:00

Add ctx to session API

This commit is contained in:
Ming Deng 2020-08-30 15:39:07 +00:00
parent 0019e0fc1b
commit 670064686e
23 changed files with 302 additions and 288 deletions

View File

@ -129,7 +129,7 @@ func TestCache_Scan(t *testing.T) {
t.Error("init err") t.Error("init err")
} }
// insert all // insert all
for i := 0; i < 10000; i++ { for i := 0; i < 100; i++ {
if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
@ -141,7 +141,7 @@ func TestCache_Scan(t *testing.T) {
t.Error("scan Error", err) t.Error("scan Error", err)
} }
assert.Equal(t, 10000, len(keys), "scan all error") assert.Equal(t, 100, len(keys), "scan all error")
// clear all // clear all
if err = bm.ClearAll(); err != nil { if err = bm.ClearAll(); err != nil {

View File

@ -33,6 +33,7 @@
package couchbase package couchbase
import ( import (
"context"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@ -63,7 +64,7 @@ type Provider struct {
} }
// Set value to couchabse session // Set value to couchabse session
func (cs *SessionStore) Set(key, value interface{}) error { func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
cs.lock.Lock() cs.lock.Lock()
defer cs.lock.Unlock() defer cs.lock.Unlock()
cs.values[key] = value cs.values[key] = value
@ -71,7 +72,7 @@ func (cs *SessionStore) Set(key, value interface{}) error {
} }
// Get value from couchabse session // Get value from couchabse session
func (cs *SessionStore) Get(key interface{}) interface{} { func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
cs.lock.RLock() cs.lock.RLock()
defer cs.lock.RUnlock() defer cs.lock.RUnlock()
if v, ok := cs.values[key]; ok { if v, ok := cs.values[key]; ok {
@ -81,7 +82,7 @@ func (cs *SessionStore) Get(key interface{}) interface{} {
} }
// Delete value in couchbase session by given key // Delete value in couchbase session by given key
func (cs *SessionStore) Delete(key interface{}) error { func (cs *SessionStore) Delete(ctx context.Context, key interface{}) error {
cs.lock.Lock() cs.lock.Lock()
defer cs.lock.Unlock() defer cs.lock.Unlock()
delete(cs.values, key) delete(cs.values, key)
@ -89,7 +90,7 @@ func (cs *SessionStore) Delete(key interface{}) error {
} }
// Flush Clean all values in couchbase session // Flush Clean all values in couchbase session
func (cs *SessionStore) Flush() error { func (cs *SessionStore) Flush(context.Context) error {
cs.lock.Lock() cs.lock.Lock()
defer cs.lock.Unlock() defer cs.lock.Unlock()
cs.values = make(map[interface{}]interface{}) cs.values = make(map[interface{}]interface{})
@ -97,12 +98,12 @@ func (cs *SessionStore) Flush() error {
} }
// SessionID Get couchbase session store id // SessionID Get couchbase session store id
func (cs *SessionStore) SessionID() string { func (cs *SessionStore) SessionID(context.Context) string {
return cs.sid return cs.sid
} }
// SessionRelease Write couchbase session with Gob string // SessionRelease Write couchbase session with Gob string
func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer cs.b.Close() defer cs.b.Close()
bo, err := session.EncodeGob(cs.values) bo, err := session.EncodeGob(cs.values)
@ -135,7 +136,7 @@ func (cp *Provider) getBucket() *couchbase.Bucket {
// SessionInit init couchbase session // SessionInit init couchbase session
// savepath like couchbase server REST/JSON URL // savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket // e.g. http://host:port/, Pool, Bucket
func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
cp.maxlifetime = maxlifetime cp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
if len(configs) > 0 { if len(configs) > 0 {
@ -152,7 +153,7 @@ func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error {
} }
// SessionRead read couchbase session by sid // SessionRead read couchbase session by sid
func (cp *Provider) SessionRead(sid string) (session.Store, error) { func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
cp.b = cp.getBucket() cp.b = cp.getBucket()
var ( var (
@ -179,7 +180,7 @@ func (cp *Provider) SessionRead(sid string) (session.Store, error) {
// SessionExist Check couchbase session exist. // SessionExist Check couchbase session exist.
// it checkes sid exist or not. // it checkes sid exist or not.
func (cp *Provider) SessionExist(sid string) (bool, error) { func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
cp.b = cp.getBucket() cp.b = cp.getBucket()
defer cp.b.Close() defer cp.b.Close()
@ -192,7 +193,7 @@ func (cp *Provider) SessionExist(sid string) (bool, error) {
} }
// SessionRegenerate remove oldsid and use sid to generate new session // SessionRegenerate remove oldsid and use sid to generate new session
func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
cp.b = cp.getBucket() cp.b = cp.getBucket()
var doc []byte var doc []byte
@ -225,7 +226,7 @@ func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
} }
// SessionDestroy Remove bucket in this couchbase // SessionDestroy Remove bucket in this couchbase
func (cp *Provider) SessionDestroy(sid string) error { func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error {
cp.b = cp.getBucket() cp.b = cp.getBucket()
defer cp.b.Close() defer cp.b.Close()
@ -234,11 +235,11 @@ func (cp *Provider) SessionDestroy(sid string) error {
} }
// SessionGC Recycle // SessionGC Recycle
func (cp *Provider) SessionGC() { func (cp *Provider) SessionGC(context.Context) {
} }
// SessionAll return all active session // SessionAll return all active session
func (cp *Provider) SessionAll() int { func (cp *Provider) SessionAll(context.Context) int {
return 0 return 0
} }

View File

@ -2,6 +2,7 @@
package ledis package ledis
import ( import (
"context"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -27,7 +28,7 @@ type SessionStore struct {
} }
// Set value in ledis session // Set value in ledis session
func (ls *SessionStore) Set(key, value interface{}) error { func (ls *SessionStore) Set(ctx context.Context, key, value interface{}) error {
ls.lock.Lock() ls.lock.Lock()
defer ls.lock.Unlock() defer ls.lock.Unlock()
ls.values[key] = value ls.values[key] = value
@ -35,7 +36,7 @@ func (ls *SessionStore) Set(key, value interface{}) error {
} }
// Get value in ledis session // Get value in ledis session
func (ls *SessionStore) Get(key interface{}) interface{} { func (ls *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
ls.lock.RLock() ls.lock.RLock()
defer ls.lock.RUnlock() defer ls.lock.RUnlock()
if v, ok := ls.values[key]; ok { if v, ok := ls.values[key]; ok {
@ -45,7 +46,7 @@ func (ls *SessionStore) Get(key interface{}) interface{} {
} }
// Delete value in ledis session // Delete value in ledis session
func (ls *SessionStore) Delete(key interface{}) error { func (ls *SessionStore) Delete(ctx context.Context, key interface{}) error {
ls.lock.Lock() ls.lock.Lock()
defer ls.lock.Unlock() defer ls.lock.Unlock()
delete(ls.values, key) delete(ls.values, key)
@ -53,7 +54,7 @@ func (ls *SessionStore) Delete(key interface{}) error {
} }
// Flush clear all values in ledis session // Flush clear all values in ledis session
func (ls *SessionStore) Flush() error { func (ls *SessionStore) Flush(context.Context) error {
ls.lock.Lock() ls.lock.Lock()
defer ls.lock.Unlock() defer ls.lock.Unlock()
ls.values = make(map[interface{}]interface{}) ls.values = make(map[interface{}]interface{})
@ -61,12 +62,12 @@ func (ls *SessionStore) Flush() error {
} }
// SessionID get ledis session id // SessionID get ledis session id
func (ls *SessionStore) SessionID() string { func (ls *SessionStore) SessionID(context.Context) string {
return ls.sid return ls.sid
} }
// SessionRelease save session values to ledis // SessionRelease save session values to ledis
func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(ls.values) b, err := session.EncodeGob(ls.values)
if err != nil { if err != nil {
return return
@ -85,7 +86,7 @@ type Provider struct {
// SessionInit init ledis session // SessionInit init ledis session
// savepath like ledis server saveDataPath,pool size // savepath like ledis server saveDataPath,pool size
// e.g. 127.0.0.1:6379,100,astaxie // e.g. 127.0.0.1:6379,100,astaxie
func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
var err error var err error
lp.maxlifetime = maxlifetime lp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
@ -111,7 +112,7 @@ func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error {
} }
// SessionRead read ledis session by sid // SessionRead read ledis session by sid
func (lp *Provider) SessionRead(sid string) (session.Store, error) { func (lp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var ( var (
kv map[interface{}]interface{} kv map[interface{}]interface{}
err error err error
@ -132,13 +133,13 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) {
} }
// SessionExist check ledis session exist by sid // SessionExist check ledis session exist by sid
func (lp *Provider) SessionExist(sid string) (bool, error) { func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
count, _ := c.Exists([]byte(sid)) count, _ := c.Exists([]byte(sid))
return count != 0, nil return count != 0, nil
} }
// SessionRegenerate generate new sid for ledis session // SessionRegenerate generate new sid for ledis session
func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
count, _ := c.Exists([]byte(sid)) count, _ := c.Exists([]byte(sid))
if count == 0 { if count == 0 {
// oldsid doesn't exists, set the new sid directly // oldsid doesn't exists, set the new sid directly
@ -151,21 +152,21 @@ func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
c.Set([]byte(sid), data) c.Set([]byte(sid), data)
c.Expire([]byte(sid), lp.maxlifetime) c.Expire([]byte(sid), lp.maxlifetime)
} }
return lp.SessionRead(sid) return lp.SessionRead(context.Background(), sid)
} }
// SessionDestroy delete ledis session by id // SessionDestroy delete ledis session by id
func (lp *Provider) SessionDestroy(sid string) error { func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c.Del([]byte(sid)) c.Del([]byte(sid))
return nil return nil
} }
// SessionGC Impelment method, no used. // SessionGC Impelment method, no used.
func (lp *Provider) SessionGC() { func (lp *Provider) SessionGC(context.Context) {
} }
// SessionAll return all active session // SessionAll return all active session
func (lp *Provider) SessionAll() int { func (lp *Provider) SessionAll(context.Context) int {
return 0 return 0
} }
func init() { func init() {

View File

@ -33,6 +33,7 @@
package memcache package memcache
import ( import (
"context"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@ -54,7 +55,7 @@ type SessionStore struct {
} }
// Set value in memcache session // Set value in memcache session
func (rs *SessionStore) Set(key, value interface{}) error { func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
rs.values[key] = value rs.values[key] = value
@ -62,7 +63,7 @@ func (rs *SessionStore) Set(key, value interface{}) error {
} }
// Get value in memcache session // Get value in memcache session
func (rs *SessionStore) Get(key interface{}) interface{} { func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock() rs.lock.RLock()
defer rs.lock.RUnlock() defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok { if v, ok := rs.values[key]; ok {
@ -72,7 +73,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} {
} }
// Delete value in memcache session // Delete value in memcache session
func (rs *SessionStore) Delete(key interface{}) error { func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
delete(rs.values, key) delete(rs.values, key)
@ -80,7 +81,7 @@ func (rs *SessionStore) Delete(key interface{}) error {
} }
// Flush clear all values in memcache session // Flush clear all values in memcache session
func (rs *SessionStore) Flush() error { func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{}) rs.values = make(map[interface{}]interface{})
@ -88,12 +89,12 @@ func (rs *SessionStore) Flush() error {
} }
// SessionID get memcache session id // SessionID get memcache session id
func (rs *SessionStore) SessionID() string { func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid return rs.sid
} }
// SessionRelease save session values to memcache // SessionRelease save session values to memcache
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values) b, err := session.EncodeGob(rs.values)
if err != nil { if err != nil {
return return
@ -113,7 +114,7 @@ type MemProvider struct {
// SessionInit init memcache session // SessionInit init memcache session
// savepath like // savepath like
// e.g. 127.0.0.1:9090 // e.g. 127.0.0.1:9090
func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { func (rp *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime rp.maxlifetime = maxlifetime
rp.conninfo = strings.Split(savePath, ";") rp.conninfo = strings.Split(savePath, ";")
client = memcache.New(rp.conninfo...) client = memcache.New(rp.conninfo...)
@ -121,7 +122,7 @@ func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
} }
// SessionRead read memcache session by sid // SessionRead read memcache session by sid
func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { func (rp *MemProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if client == nil { if client == nil {
if err := rp.connectInit(); err != nil { if err := rp.connectInit(); err != nil {
return nil, err return nil, err
@ -149,7 +150,7 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) {
} }
// SessionExist check memcache session exist by sid // SessionExist check memcache session exist by sid
func (rp *MemProvider) SessionExist(sid string) (bool, error) { func (rp *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
if client == nil { if client == nil {
if err := rp.connectInit(); err != nil { if err := rp.connectInit(); err != nil {
return false, err return false, err
@ -162,7 +163,7 @@ func (rp *MemProvider) SessionExist(sid string) (bool, error) {
} }
// SessionRegenerate generate new sid for memcache session // SessionRegenerate generate new sid for memcache session
func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
if client == nil { if client == nil {
if err := rp.connectInit(); err != nil { if err := rp.connectInit(); err != nil {
return nil, err return nil, err
@ -201,7 +202,7 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, err
} }
// SessionDestroy delete memcache session by id // SessionDestroy delete memcache session by id
func (rp *MemProvider) SessionDestroy(sid string) error { func (rp *MemProvider) SessionDestroy(ctx context.Context, sid string) error {
if client == nil { if client == nil {
if err := rp.connectInit(); err != nil { if err := rp.connectInit(); err != nil {
return err return err
@ -217,11 +218,11 @@ func (rp *MemProvider) connectInit() error {
} }
// SessionGC Impelment method, no used. // SessionGC Impelment method, no used.
func (rp *MemProvider) SessionGC() { func (rp *MemProvider) SessionGC(context.Context) {
} }
// SessionAll return all activeSession // SessionAll return all activeSession
func (rp *MemProvider) SessionAll() int { func (rp *MemProvider) SessionAll(context.Context) int {
return 0 return 0
} }

View File

@ -41,6 +41,7 @@
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"net/http" "net/http"
"sync" "sync"
@ -67,7 +68,7 @@ type SessionStore struct {
// Set value in mysql session. // Set value in mysql session.
// it is temp value in map. // it is temp value in map.
func (st *SessionStore) Set(key, value interface{}) error { func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
st.values[key] = value st.values[key] = value
@ -75,7 +76,7 @@ func (st *SessionStore) Set(key, value interface{}) error {
} }
// Get value from mysql session // Get value from mysql session
func (st *SessionStore) Get(key interface{}) interface{} { func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock() st.lock.RLock()
defer st.lock.RUnlock() defer st.lock.RUnlock()
if v, ok := st.values[key]; ok { if v, ok := st.values[key]; ok {
@ -85,7 +86,7 @@ func (st *SessionStore) Get(key interface{}) interface{} {
} }
// Delete value in mysql session // Delete value in mysql session
func (st *SessionStore) Delete(key interface{}) error { func (st *SessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
delete(st.values, key) delete(st.values, key)
@ -93,7 +94,7 @@ func (st *SessionStore) Delete(key interface{}) error {
} }
// Flush clear all values in mysql session // Flush clear all values in mysql session
func (st *SessionStore) Flush() error { func (st *SessionStore) Flush(context.Context) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
st.values = make(map[interface{}]interface{}) st.values = make(map[interface{}]interface{})
@ -101,13 +102,13 @@ func (st *SessionStore) Flush() error {
} }
// SessionID get session id of this mysql session store // SessionID get session id of this mysql session store
func (st *SessionStore) SessionID() string { func (st *SessionStore) SessionID(context.Context) string {
return st.sid return st.sid
} }
// SessionRelease save mysql session values to database. // SessionRelease save mysql session values to database.
// must call this method to save values to database. // must call this method to save values to database.
func (st *SessionStore) SessionRelease(w http.ResponseWriter) { func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer st.c.Close() defer st.c.Close()
b, err := session.EncodeGob(st.values) b, err := session.EncodeGob(st.values)
if err != nil { if err != nil {
@ -134,14 +135,14 @@ func (mp *Provider) connectInit() *sql.DB {
// SessionInit init mysql session. // SessionInit init mysql session.
// savepath is the connection string of mysql. // savepath is the connection string of mysql.
func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime mp.maxlifetime = maxlifetime
mp.savePath = savePath mp.savePath = savePath
return nil return nil
} }
// SessionRead get mysql session by sid // SessionRead get mysql session by sid
func (mp *Provider) SessionRead(sid string) (session.Store, error) { func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
c := mp.connectInit() c := mp.connectInit()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte var sessiondata []byte
@ -164,7 +165,7 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) {
} }
// SessionExist check mysql session exist // SessionExist check mysql session exist
func (mp *Provider) SessionExist(sid string) (bool, error) { func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := mp.connectInit() c := mp.connectInit()
defer c.Close() defer c.Close()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
@ -180,7 +181,7 @@ func (mp *Provider) SessionExist(sid string) (bool, error) {
} }
// SessionRegenerate generate new sid for mysql session // SessionRegenerate generate new sid for mysql session
func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := mp.connectInit() c := mp.connectInit()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid)
var sessiondata []byte var sessiondata []byte
@ -203,7 +204,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
} }
// SessionDestroy delete mysql session by sid // SessionDestroy delete mysql session by sid
func (mp *Provider) SessionDestroy(sid string) error { func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := mp.connectInit() c := mp.connectInit()
c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) c.Exec("DELETE FROM "+TableName+" where session_key=?", sid)
c.Close() c.Close()
@ -211,14 +212,14 @@ func (mp *Provider) SessionDestroy(sid string) error {
} }
// SessionGC delete expired values in mysql session // SessionGC delete expired values in mysql session
func (mp *Provider) SessionGC() { func (mp *Provider) SessionGC(context.Context) {
c := mp.connectInit() c := mp.connectInit()
c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
c.Close() c.Close()
} }
// SessionAll count values in mysql session // SessionAll count values in mysql session
func (mp *Provider) SessionAll() int { func (mp *Provider) SessionAll(context.Context) int {
c := mp.connectInit() c := mp.connectInit()
defer c.Close() defer c.Close()
var total int var total int

View File

@ -51,6 +51,7 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"net/http" "net/http"
"sync" "sync"
@ -73,7 +74,7 @@ type SessionStore struct {
// Set value in postgresql session. // Set value in postgresql session.
// it is temp value in map. // it is temp value in map.
func (st *SessionStore) Set(key, value interface{}) error { func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
st.values[key] = value st.values[key] = value
@ -81,7 +82,7 @@ func (st *SessionStore) Set(key, value interface{}) error {
} }
// Get value from postgresql session // Get value from postgresql session
func (st *SessionStore) Get(key interface{}) interface{} { func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock() st.lock.RLock()
defer st.lock.RUnlock() defer st.lock.RUnlock()
if v, ok := st.values[key]; ok { if v, ok := st.values[key]; ok {
@ -91,7 +92,7 @@ func (st *SessionStore) Get(key interface{}) interface{} {
} }
// Delete value in postgresql session // Delete value in postgresql session
func (st *SessionStore) Delete(key interface{}) error { func (st *SessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
delete(st.values, key) delete(st.values, key)
@ -99,7 +100,7 @@ func (st *SessionStore) Delete(key interface{}) error {
} }
// Flush clear all values in postgresql session // Flush clear all values in postgresql session
func (st *SessionStore) Flush() error { func (st *SessionStore) Flush(context.Context) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
st.values = make(map[interface{}]interface{}) st.values = make(map[interface{}]interface{})
@ -107,13 +108,13 @@ func (st *SessionStore) Flush() error {
} }
// SessionID get session id of this postgresql session store // SessionID get session id of this postgresql session store
func (st *SessionStore) SessionID() string { func (st *SessionStore) SessionID(context.Context) string {
return st.sid return st.sid
} }
// SessionRelease save postgresql session values to database. // SessionRelease save postgresql session values to database.
// must call this method to save values to database. // must call this method to save values to database.
func (st *SessionStore) SessionRelease(w http.ResponseWriter) { func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer st.c.Close() defer st.c.Close()
b, err := session.EncodeGob(st.values) b, err := session.EncodeGob(st.values)
if err != nil { if err != nil {
@ -141,14 +142,14 @@ func (mp *Provider) connectInit() *sql.DB {
// SessionInit init postgresql session. // SessionInit init postgresql session.
// savepath is the connection string of postgresql. // savepath is the connection string of postgresql.
func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime mp.maxlifetime = maxlifetime
mp.savePath = savePath mp.savePath = savePath
return nil return nil
} }
// SessionRead get postgresql session by sid // SessionRead get postgresql session by sid
func (mp *Provider) SessionRead(sid string) (session.Store, error) { func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
c := mp.connectInit() c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", sid) row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte var sessiondata []byte
@ -178,7 +179,7 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) {
} }
// SessionExist check postgresql session exist // SessionExist check postgresql session exist
func (mp *Provider) SessionExist(sid string) (bool, error) { func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := mp.connectInit() c := mp.connectInit()
defer c.Close() defer c.Close()
row := c.QueryRow("select session_data from session where session_key=$1", sid) row := c.QueryRow("select session_data from session where session_key=$1", sid)
@ -194,7 +195,7 @@ func (mp *Provider) SessionExist(sid string) (bool, error) {
} }
// SessionRegenerate generate new sid for postgresql session // SessionRegenerate generate new sid for postgresql session
func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := mp.connectInit() c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", oldsid) row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
var sessiondata []byte var sessiondata []byte
@ -218,7 +219,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
} }
// SessionDestroy delete postgresql session by sid // SessionDestroy delete postgresql session by sid
func (mp *Provider) SessionDestroy(sid string) error { func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := mp.connectInit() c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=$1", sid) c.Exec("DELETE FROM session where session_key=$1", sid)
c.Close() c.Close()
@ -226,14 +227,14 @@ func (mp *Provider) SessionDestroy(sid string) error {
} }
// SessionGC delete expired values in postgresql session // SessionGC delete expired values in postgresql session
func (mp *Provider) SessionGC() { func (mp *Provider) SessionGC(context.Context) {
c := mp.connectInit() c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close() c.Close()
} }
// SessionAll count values in postgresql session // SessionAll count values in postgresql session
func (mp *Provider) SessionAll() int { func (mp *Provider) SessionAll(context.Context) int {
c := mp.connectInit() c := mp.connectInit()
defer c.Close() defer c.Close()
var total int var total int

View File

@ -33,6 +33,7 @@
package redis package redis
import ( import (
"context"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -59,7 +60,7 @@ type SessionStore struct {
} }
// Set value in redis session // Set value in redis session
func (rs *SessionStore) Set(key, value interface{}) error { func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
rs.values[key] = value rs.values[key] = value
@ -67,7 +68,7 @@ func (rs *SessionStore) Set(key, value interface{}) error {
} }
// Get value in redis session // Get value in redis session
func (rs *SessionStore) Get(key interface{}) interface{} { func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock() rs.lock.RLock()
defer rs.lock.RUnlock() defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok { if v, ok := rs.values[key]; ok {
@ -77,7 +78,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} {
} }
// Delete value in redis session // Delete value in redis session
func (rs *SessionStore) Delete(key interface{}) error { func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
delete(rs.values, key) delete(rs.values, key)
@ -85,7 +86,7 @@ func (rs *SessionStore) Delete(key interface{}) error {
} }
// Flush clear all values in redis session // Flush clear all values in redis session
func (rs *SessionStore) Flush() error { func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{}) rs.values = make(map[interface{}]interface{})
@ -93,12 +94,12 @@ func (rs *SessionStore) Flush() error {
} }
// SessionID get redis session id // SessionID get redis session id
func (rs *SessionStore) SessionID() string { func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid return rs.sid
} }
// SessionRelease save session values to redis // SessionRelease save session values to redis
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values) b, err := session.EncodeGob(rs.values)
if err != nil { if err != nil {
return return
@ -123,7 +124,7 @@ type Provider struct {
// SessionInit init redis session // SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second // savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
// e.g. 127.0.0.1:6379,100,astaxie,0,30 // e.g. 127.0.0.1:6379,100,astaxie,0,30
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
if len(configs) > 0 { if len(configs) > 0 {
@ -185,7 +186,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
} }
// SessionRead read redis session by sid // SessionRead read redis session by sid
func (rp *Provider) SessionRead(sid string) (session.Store, error) { func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result() kvs, err := rp.poollist.Get(sid).Result()
@ -205,7 +206,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) {
} }
// SessionExist check redis session exist by sid // SessionExist check redis session exist by sid
func (rp *Provider) SessionExist(sid string) (bool, error) { func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
@ -215,7 +216,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) {
} }
// SessionRegenerate generate new sid for redis session // SessionRegenerate generate new sid for redis session
func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist c := rp.poollist
if existed, _ := c.Exists(oldsid).Result(); existed == 0 { if existed, _ := c.Exists(oldsid).Result(); existed == 0 {
// oldsid doesn't exists, set the new sid directly // oldsid doesn't exists, set the new sid directly
@ -226,11 +227,11 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
c.Rename(oldsid, sid) c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
} }
return rp.SessionRead(sid) return rp.SessionRead(context.Background(), sid)
} }
// SessionDestroy delete redis session by id // SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(sid string) error { func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist c := rp.poollist
c.Del(sid) c.Del(sid)
@ -238,11 +239,11 @@ func (rp *Provider) SessionDestroy(sid string) error {
} }
// SessionGC Impelment method, no used. // SessionGC Impelment method, no used.
func (rp *Provider) SessionGC() { func (rp *Provider) SessionGC(context.Context) {
} }
// SessionAll return all activeSession // SessionAll return all activeSession
func (rp *Provider) SessionAll() int { func (rp *Provider) SessionAll(context.Context) int {
return 0 return 0
} }

View File

@ -40,57 +40,57 @@ func TestRedis(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("session start failed:", err) t.Fatal("session start failed:", err)
} }
defer sess.SessionRelease(w) defer sess.SessionRelease(nil, w)
// SET AND GET // SET AND GET
err = sess.Set("username", "astaxie") err = sess.Set(nil, "username", "astaxie")
if err != nil { if err != nil {
t.Fatal("set username failed:", err) t.Fatal("set username failed:", err)
} }
username := sess.Get("username") username := sess.Get(nil, "username")
if username != "astaxie" { if username != "astaxie" {
t.Fatal("get username failed") t.Fatal("get username failed")
} }
// DELETE // DELETE
err = sess.Delete("username") err = sess.Delete(nil, "username")
if err != nil { if err != nil {
t.Fatal("delete username failed:", err) t.Fatal("delete username failed:", err)
} }
username = sess.Get("username") username = sess.Get(nil, "username")
if username != nil { if username != nil {
t.Fatal("delete username failed") t.Fatal("delete username failed")
} }
// FLUSH // FLUSH
err = sess.Set("username", "astaxie") err = sess.Set(nil, "username", "astaxie")
if err != nil { if err != nil {
t.Fatal("set failed:", err) t.Fatal("set failed:", err)
} }
err = sess.Set("password", "1qaz2wsx") err = sess.Set(nil, "password", "1qaz2wsx")
if err != nil { if err != nil {
t.Fatal("set failed:", err) t.Fatal("set failed:", err)
} }
username = sess.Get("username") username = sess.Get(nil, "username")
if username != "astaxie" { if username != "astaxie" {
t.Fatal("get username failed") t.Fatal("get username failed")
} }
password := sess.Get("password") password := sess.Get(nil, "password")
if password != "1qaz2wsx" { if password != "1qaz2wsx" {
t.Fatal("get password failed") t.Fatal("get password failed")
} }
err = sess.Flush() err = sess.Flush(nil)
if err != nil { if err != nil {
t.Fatal("flush failed:", err) t.Fatal("flush failed:", err)
} }
username = sess.Get("username") username = sess.Get(nil, "username")
if username != nil { if username != nil {
t.Fatal("flush failed") t.Fatal("flush failed")
} }
password = sess.Get("password") password = sess.Get(nil, "password")
if password != nil { if password != nil {
t.Fatal("flush failed") t.Fatal("flush failed")
} }
sess.SessionRelease(w) sess.SessionRelease(nil, w)
} }

View File

@ -33,6 +33,7 @@
package redis_cluster package redis_cluster
import ( import (
"context"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -58,7 +59,7 @@ type SessionStore struct {
} }
// Set value in redis_cluster session // Set value in redis_cluster session
func (rs *SessionStore) Set(key, value interface{}) error { func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
rs.values[key] = value rs.values[key] = value
@ -66,7 +67,7 @@ func (rs *SessionStore) Set(key, value interface{}) error {
} }
// Get value in redis_cluster session // Get value in redis_cluster session
func (rs *SessionStore) Get(key interface{}) interface{} { func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock() rs.lock.RLock()
defer rs.lock.RUnlock() defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok { if v, ok := rs.values[key]; ok {
@ -76,7 +77,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} {
} }
// Delete value in redis_cluster session // Delete value in redis_cluster session
func (rs *SessionStore) Delete(key interface{}) error { func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
delete(rs.values, key) delete(rs.values, key)
@ -84,7 +85,7 @@ func (rs *SessionStore) Delete(key interface{}) error {
} }
// Flush clear all values in redis_cluster session // Flush clear all values in redis_cluster session
func (rs *SessionStore) Flush() error { func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{}) rs.values = make(map[interface{}]interface{})
@ -92,12 +93,12 @@ func (rs *SessionStore) Flush() error {
} }
// SessionID get redis_cluster session id // SessionID get redis_cluster session id
func (rs *SessionStore) SessionID() string { func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid return rs.sid
} }
// SessionRelease save session values to redis_cluster // SessionRelease save session values to redis_cluster
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values) b, err := session.EncodeGob(rs.values)
if err != nil { if err != nil {
return return
@ -122,7 +123,7 @@ type Provider struct {
// SessionInit init redis_cluster session // SessionInit init redis_cluster session
// savepath like redis server addr,pool size,password,dbnum // savepath like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 // e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
if len(configs) > 0 { if len(configs) > 0 {
@ -182,7 +183,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
} }
// SessionRead read redis_cluster session by sid // SessionRead read redis_cluster session by sid
func (rp *Provider) SessionRead(sid string) (session.Store, error) { func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result() kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != rediss.Nil { if err != nil && err != rediss.Nil {
@ -201,7 +202,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) {
} }
// SessionExist check redis_cluster session exist by sid // SessionExist check redis_cluster session exist by sid
func (rp *Provider) SessionExist(sid string) (bool, error) { func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err return false, err
@ -210,7 +211,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) {
} }
// SessionRegenerate generate new sid for redis_cluster session // SessionRegenerate generate new sid for redis_cluster session
func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
@ -222,22 +223,22 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
c.Rename(oldsid, sid) c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
} }
return rp.SessionRead(sid) return rp.SessionRead(context.Background(), sid)
} }
// SessionDestroy delete redis session by id // SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(sid string) error { func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist c := rp.poollist
c.Del(sid) c.Del(sid)
return nil return nil
} }
// SessionGC Impelment method, no used. // SessionGC Impelment method, no used.
func (rp *Provider) SessionGC() { func (rp *Provider) SessionGC(context.Context) {
} }
// SessionAll return all activeSession // SessionAll return all activeSession
func (rp *Provider) SessionAll() int { func (rp *Provider) SessionAll(context.Context) int {
return 0 return 0
} }

View File

@ -33,6 +33,7 @@
package redis_sentinel package redis_sentinel
import ( import (
"context"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -58,7 +59,7 @@ type SessionStore struct {
} }
// Set value in redis_sentinel session // Set value in redis_sentinel session
func (rs *SessionStore) Set(key, value interface{}) error { func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
rs.values[key] = value rs.values[key] = value
@ -66,7 +67,7 @@ func (rs *SessionStore) Set(key, value interface{}) error {
} }
// Get value in redis_sentinel session // Get value in redis_sentinel session
func (rs *SessionStore) Get(key interface{}) interface{} { func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock() rs.lock.RLock()
defer rs.lock.RUnlock() defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok { if v, ok := rs.values[key]; ok {
@ -76,7 +77,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} {
} }
// Delete value in redis_sentinel session // Delete value in redis_sentinel session
func (rs *SessionStore) Delete(key interface{}) error { func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
delete(rs.values, key) delete(rs.values, key)
@ -84,7 +85,7 @@ func (rs *SessionStore) Delete(key interface{}) error {
} }
// Flush clear all values in redis_sentinel session // Flush clear all values in redis_sentinel session
func (rs *SessionStore) Flush() error { func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock() rs.lock.Lock()
defer rs.lock.Unlock() defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{}) rs.values = make(map[interface{}]interface{})
@ -92,12 +93,12 @@ func (rs *SessionStore) Flush() error {
} }
// SessionID get redis_sentinel session id // SessionID get redis_sentinel session id
func (rs *SessionStore) SessionID() string { func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid return rs.sid
} }
// SessionRelease save session values to redis_sentinel // SessionRelease save session values to redis_sentinel
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values) b, err := session.EncodeGob(rs.values)
if err != nil { if err != nil {
return return
@ -123,7 +124,7 @@ type Provider struct {
// SessionInit init redis_sentinel session // SessionInit init redis_sentinel session
// savepath like redis sentinel addr,pool size,password,dbnum,masterName // savepath like redis sentinel addr,pool size,password,dbnum,masterName
// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster // e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
if len(configs) > 0 { if len(configs) > 0 {
@ -195,7 +196,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
} }
// SessionRead read redis_sentinel session by sid // SessionRead read redis_sentinel session by sid
func (rp *Provider) SessionRead(sid string) (session.Store, error) { func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result() kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != redis.Nil { if err != nil && err != redis.Nil {
@ -214,7 +215,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) {
} }
// SessionExist check redis_sentinel session exist by sid // SessionExist check redis_sentinel session exist by sid
func (rp *Provider) SessionExist(sid string) (bool, error) { func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err return false, err
@ -223,7 +224,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) {
} }
// SessionRegenerate generate new sid for redis_sentinel session // SessionRegenerate generate new sid for redis_sentinel session
func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
@ -235,22 +236,22 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
c.Rename(oldsid, sid) c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
} }
return rp.SessionRead(sid) return rp.SessionRead(context.Background(), sid)
} }
// SessionDestroy delete redis session by id // SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(sid string) error { func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist c := rp.poollist
c.Del(sid) c.Del(sid)
return nil return nil
} }
// SessionGC Impelment method, no used. // SessionGC Impelment method, no used.
func (rp *Provider) SessionGC() { func (rp *Provider) SessionGC(context.Context) {
} }
// SessionAll return all activeSession // SessionAll return all activeSession
func (rp *Provider) SessionAll() int { func (rp *Provider) SessionAll(context.Context) int {
return 0 return 0
} }

View File

@ -33,58 +33,58 @@ func TestRedisSentinel(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("session start failed:", err) t.Fatal("session start failed:", err)
} }
defer sess.SessionRelease(w) defer sess.SessionRelease(nil, w)
// SET AND GET // SET AND GET
err = sess.Set("username", "astaxie") err = sess.Set(nil, "username", "astaxie")
if err != nil { if err != nil {
t.Fatal("set username failed:", err) t.Fatal("set username failed:", err)
} }
username := sess.Get("username") username := sess.Get(nil, "username")
if username != "astaxie" { if username != "astaxie" {
t.Fatal("get username failed") t.Fatal("get username failed")
} }
// DELETE // DELETE
err = sess.Delete("username") err = sess.Delete(nil, "username")
if err != nil { if err != nil {
t.Fatal("delete username failed:", err) t.Fatal("delete username failed:", err)
} }
username = sess.Get("username") username = sess.Get(nil, "username")
if username != nil { if username != nil {
t.Fatal("delete username failed") t.Fatal("delete username failed")
} }
// FLUSH // FLUSH
err = sess.Set("username", "astaxie") err = sess.Set(nil, "username", "astaxie")
if err != nil { if err != nil {
t.Fatal("set failed:", err) t.Fatal("set failed:", err)
} }
err = sess.Set("password", "1qaz2wsx") err = sess.Set(nil, "password", "1qaz2wsx")
if err != nil { if err != nil {
t.Fatal("set failed:", err) t.Fatal("set failed:", err)
} }
username = sess.Get("username") username = sess.Get(nil, "username")
if username != "astaxie" { if username != "astaxie" {
t.Fatal("get username failed") t.Fatal("get username failed")
} }
password := sess.Get("password") password := sess.Get(nil, "password")
if password != "1qaz2wsx" { if password != "1qaz2wsx" {
t.Fatal("get password failed") t.Fatal("get password failed")
} }
err = sess.Flush() err = sess.Flush(nil)
if err != nil { if err != nil {
t.Fatal("flush failed:", err) t.Fatal("flush failed:", err)
} }
username = sess.Get("username") username = sess.Get(nil, "username")
if username != nil { if username != nil {
t.Fatal("flush failed") t.Fatal("flush failed")
} }
password = sess.Get("password") password = sess.Get(nil, "password")
if password != nil { if password != nil {
t.Fatal("flush failed") t.Fatal("flush failed")
} }
sess.SessionRelease(w) sess.SessionRelease(nil, w)
} }

View File

@ -15,6 +15,7 @@
package session package session
import ( import (
"context"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"encoding/json" "encoding/json"
@ -34,7 +35,7 @@ type CookieSessionStore struct {
// Set value to cookie session. // Set value to cookie session.
// the value are encoded as gob with hash block string. // the value are encoded as gob with hash block string.
func (st *CookieSessionStore) Set(key, value interface{}) error { func (st *CookieSessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
st.values[key] = value st.values[key] = value
@ -42,7 +43,7 @@ func (st *CookieSessionStore) Set(key, value interface{}) error {
} }
// Get value from cookie session // Get value from cookie session
func (st *CookieSessionStore) Get(key interface{}) interface{} { func (st *CookieSessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock() st.lock.RLock()
defer st.lock.RUnlock() defer st.lock.RUnlock()
if v, ok := st.values[key]; ok { if v, ok := st.values[key]; ok {
@ -52,7 +53,7 @@ func (st *CookieSessionStore) Get(key interface{}) interface{} {
} }
// Delete value in cookie session // Delete value in cookie session
func (st *CookieSessionStore) Delete(key interface{}) error { func (st *CookieSessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
delete(st.values, key) delete(st.values, key)
@ -60,7 +61,7 @@ func (st *CookieSessionStore) Delete(key interface{}) error {
} }
// Flush Clean all values in cookie session // Flush Clean all values in cookie session
func (st *CookieSessionStore) Flush() error { func (st *CookieSessionStore) Flush(context.Context) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
st.values = make(map[interface{}]interface{}) st.values = make(map[interface{}]interface{})
@ -68,12 +69,12 @@ func (st *CookieSessionStore) Flush() error {
} }
// SessionID Return id of this cookie session // SessionID Return id of this cookie session
func (st *CookieSessionStore) SessionID() string { func (st *CookieSessionStore) SessionID(context.Context) string {
return st.sid return st.sid
} }
// SessionRelease Write cookie session to http response cookie // SessionRelease Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
st.lock.Lock() st.lock.Lock()
encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values)
st.lock.Unlock() st.lock.Unlock()
@ -112,7 +113,7 @@ type CookieProvider struct {
// securityName - recognized name in encoded cookie string // securityName - recognized name in encoded cookie string
// cookieName - cookie name // cookieName - cookie name
// maxage - cookie max life time. // maxage - cookie max life time.
func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { func (pder *CookieProvider) SessionInit(ctx context.Context, maxlifetime int64, config string) error {
pder.config = &cookieConfig{} pder.config = &cookieConfig{}
err := json.Unmarshal([]byte(config), pder.config) err := json.Unmarshal([]byte(config), pder.config)
if err != nil { if err != nil {
@ -134,7 +135,7 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error
// SessionRead Get SessionStore in cooke. // SessionRead Get SessionStore in cooke.
// decode cooke string to map and put into SessionStore with sid. // decode cooke string to map and put into SessionStore with sid.
func (pder *CookieProvider) SessionRead(sid string) (Store, error) { func (pder *CookieProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
maps, _ := decodeCookie(pder.block, maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey, pder.config.SecurityKey,
pder.config.SecurityName, pder.config.SecurityName,
@ -147,26 +148,26 @@ func (pder *CookieProvider) SessionRead(sid string) (Store, error) {
} }
// SessionExist Cookie session is always existed // SessionExist Cookie session is always existed
func (pder *CookieProvider) SessionExist(sid string) (bool, error) { func (pder *CookieProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
return true, nil return true, nil
} }
// SessionRegenerate Implement method, no used. // SessionRegenerate Implement method, no used.
func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { func (pder *CookieProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
return nil, nil return nil, nil
} }
// SessionDestroy Implement method, no used. // SessionDestroy Implement method, no used.
func (pder *CookieProvider) SessionDestroy(sid string) error { func (pder *CookieProvider) SessionDestroy(ctx context.Context, sid string) error {
return nil return nil
} }
// SessionGC Implement method, no used. // SessionGC Implement method, no used.
func (pder *CookieProvider) SessionGC() { func (pder *CookieProvider) SessionGC(context.Context) {
} }
// SessionAll Implement method, return 0. // SessionAll Implement method, return 0.
func (pder *CookieProvider) SessionAll() int { func (pder *CookieProvider) SessionAll(context.Context) int {
return 0 return 0
} }

View File

@ -38,14 +38,14 @@ func TestCookie(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("set error,", err) t.Fatal("set error,", err)
} }
err = sess.Set("username", "astaxie") err = sess.Set(nil, "username", "astaxie")
if err != nil { if err != nil {
t.Fatal("set error,", err) t.Fatal("set error,", err)
} }
if username := sess.Get("username"); username != "astaxie" { if username := sess.Get(nil, "username"); username != "astaxie" {
t.Fatal("get username error") t.Fatal("get username error")
} }
sess.SessionRelease(w) sess.SessionRelease(nil, w)
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error") t.Fatal("setcookie error")
} else { } else {
@ -85,7 +85,7 @@ func TestDestorySessionCookie(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("session start err,", err) t.Fatal("session start err,", err)
} }
if newSession.SessionID() != session.SessionID() { if newSession.SessionID(nil) != session.SessionID(nil) {
t.Fatal("get cookie session id is not the same again.") t.Fatal("get cookie session id is not the same again.")
} }
@ -99,7 +99,7 @@ func TestDestorySessionCookie(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("session start error") t.Fatal("session start error")
} }
if newSession.SessionID() == session.SessionID() { if newSession.SessionID(nil) == session.SessionID(nil) {
t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") t.Fatal("after destroy session and reqeust again ,get cookie session id is same.")
} }
} }

View File

@ -15,6 +15,7 @@
package session package session
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -40,7 +41,7 @@ type FileSessionStore struct {
} }
// Set value to file session // Set value to file session
func (fs *FileSessionStore) Set(key, value interface{}) error { func (fs *FileSessionStore) Set(ctx context.Context, key, value interface{}) error {
fs.lock.Lock() fs.lock.Lock()
defer fs.lock.Unlock() defer fs.lock.Unlock()
fs.values[key] = value fs.values[key] = value
@ -48,7 +49,7 @@ func (fs *FileSessionStore) Set(key, value interface{}) error {
} }
// Get value from file session // Get value from file session
func (fs *FileSessionStore) Get(key interface{}) interface{} { func (fs *FileSessionStore) Get(ctx context.Context, key interface{}) interface{} {
fs.lock.RLock() fs.lock.RLock()
defer fs.lock.RUnlock() defer fs.lock.RUnlock()
if v, ok := fs.values[key]; ok { if v, ok := fs.values[key]; ok {
@ -58,7 +59,7 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} {
} }
// Delete value in file session by given key // Delete value in file session by given key
func (fs *FileSessionStore) Delete(key interface{}) error { func (fs *FileSessionStore) Delete(ctx context.Context, key interface{}) error {
fs.lock.Lock() fs.lock.Lock()
defer fs.lock.Unlock() defer fs.lock.Unlock()
delete(fs.values, key) delete(fs.values, key)
@ -66,7 +67,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error {
} }
// Flush Clean all values in file session // Flush Clean all values in file session
func (fs *FileSessionStore) Flush() error { func (fs *FileSessionStore) Flush(context.Context) error {
fs.lock.Lock() fs.lock.Lock()
defer fs.lock.Unlock() defer fs.lock.Unlock()
fs.values = make(map[interface{}]interface{}) fs.values = make(map[interface{}]interface{})
@ -74,12 +75,12 @@ func (fs *FileSessionStore) Flush() error {
} }
// SessionID Get file session store id // SessionID Get file session store id
func (fs *FileSessionStore) SessionID() string { func (fs *FileSessionStore) SessionID(context.Context) string {
return fs.sid return fs.sid
} }
// SessionRelease Write file session to local file with Gob string // SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()
b, err := EncodeGob(fs.values) b, err := EncodeGob(fs.values)
@ -119,7 +120,7 @@ type FileProvider struct {
// SessionInit Init file session provider. // SessionInit Init file session provider.
// savePath sets the session files path. // savePath sets the session files path.
func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { func (fp *FileProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
fp.maxlifetime = maxlifetime fp.maxlifetime = maxlifetime
fp.savePath = savePath fp.savePath = savePath
return nil return nil
@ -128,7 +129,7 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
// SessionRead Read file session by sid. // SessionRead Read file session by sid.
// if file is not exist, create it. // if file is not exist, create it.
// the file path is generated from sid string. // the file path is generated from sid string.
func (fp *FileProvider) SessionRead(sid string) (Store, error) { func (fp *FileProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
invalidChars := "./" invalidChars := "./"
if strings.ContainsAny(sid, invalidChars) { if strings.ContainsAny(sid, invalidChars) {
return nil, errors.New("the sid shouldn't have following characters: " + invalidChars) return nil, errors.New("the sid shouldn't have following characters: " + invalidChars)
@ -176,7 +177,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
// SessionExist Check file session exist. // SessionExist Check file session exist.
// it checks the file named from sid exist or not. // it checks the file named from sid exist or not.
func (fp *FileProvider) SessionExist(sid string) (bool, error) { func (fp *FileProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()
@ -190,7 +191,7 @@ func (fp *FileProvider) SessionExist(sid string) (bool, error) {
} }
// SessionDestroy Remove all files in this save path // SessionDestroy Remove all files in this save path
func (fp *FileProvider) SessionDestroy(sid string) error { func (fp *FileProvider) SessionDestroy(ctx context.Context, sid string) error {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()
os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
@ -198,7 +199,7 @@ func (fp *FileProvider) SessionDestroy(sid string) error {
} }
// SessionGC Recycle files in save path // SessionGC Recycle files in save path
func (fp *FileProvider) SessionGC() { func (fp *FileProvider) SessionGC(context.Context) {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()
@ -208,7 +209,7 @@ func (fp *FileProvider) SessionGC() {
// SessionAll Get active file session number. // SessionAll Get active file session number.
// it walks save path to count files. // it walks save path to count files.
func (fp *FileProvider) SessionAll() int { func (fp *FileProvider) SessionAll(context.Context) int {
a := &activeSession{} a := &activeSession{}
err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error {
return a.visit(path, f, err) return a.visit(path, f, err)
@ -222,7 +223,7 @@ func (fp *FileProvider) SessionAll() int {
// SessionRegenerate Generate new sid for file session. // SessionRegenerate Generate new sid for file session.
// it delete old file and create new file named from new sid. // it delete old file and create new file named from new sid.
func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()

View File

@ -15,6 +15,7 @@
package session package session
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@ -37,7 +38,7 @@ func TestFileProvider_SessionInit(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
if fp.maxlifetime != 180 { if fp.maxlifetime != 180 {
t.Error() t.Error()
} }
@ -54,9 +55,9 @@ func TestFileProvider_SessionExist(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
exists, err := fp.SessionExist(sid) exists, err := fp.SessionExist(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -64,12 +65,12 @@ func TestFileProvider_SessionExist(t *testing.T) {
t.Error() t.Error()
} }
_, err = fp.SessionRead(sid) _, err = fp.SessionRead(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
exists, err = fp.SessionExist(sid) exists, err = fp.SessionExist(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -85,9 +86,9 @@ func TestFileProvider_SessionExist2(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
exists, err := fp.SessionExist(sid) exists, err := fp.SessionExist(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -95,7 +96,7 @@ func TestFileProvider_SessionExist2(t *testing.T) {
t.Error() t.Error()
} }
exists, err = fp.SessionExist("") exists, err = fp.SessionExist(context.Background(), "")
if err == nil { if err == nil {
t.Error() t.Error()
} }
@ -103,7 +104,7 @@ func TestFileProvider_SessionExist2(t *testing.T) {
t.Error() t.Error()
} }
exists, err = fp.SessionExist("1") exists, err = fp.SessionExist(context.Background(), "1")
if err == nil { if err == nil {
t.Error() t.Error()
} }
@ -119,15 +120,15 @@ func TestFileProvider_SessionRead(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
s, err := fp.SessionRead(sid) s, err := fp.SessionRead(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
_ = s.Set("sessionValue", 18975) _ = s.Set(nil, "sessionValue", 18975)
v := s.Get("sessionValue") v := s.Get(nil, "sessionValue")
if v.(int) != 18975 { if v.(int) != 18975 {
t.Error() t.Error()
@ -141,14 +142,14 @@ func TestFileProvider_SessionRead1(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead("") _, err := fp.SessionRead(context.Background(), "")
if err == nil { if err == nil {
t.Error(err) t.Error(err)
} }
_, err = fp.SessionRead("1") _, err = fp.SessionRead(context.Background(), "1")
if err == nil { if err == nil {
t.Error(err) t.Error(err)
} }
@ -161,18 +162,18 @@ func TestFileProvider_SessionAll(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 546 sessionCount := 546
for i := 1; i <= sessionCount; i++ { for i := 1; i <= sessionCount; i++ {
_, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) _, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
} }
if fp.SessionAll() != sessionCount { if fp.SessionAll(nil) != sessionCount {
t.Error() t.Error()
} }
} }
@ -184,14 +185,14 @@ func TestFileProvider_SessionRegenerate(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(sid) _, err := fp.SessionRead(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
exists, err := fp.SessionExist(sid) exists, err := fp.SessionExist(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -199,12 +200,12 @@ func TestFileProvider_SessionRegenerate(t *testing.T) {
t.Error() t.Error()
} }
_, err = fp.SessionRegenerate(sid, sidNew) _, err = fp.SessionRegenerate(context.Background(), sid, sidNew)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
exists, err = fp.SessionExist(sid) exists, err = fp.SessionExist(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -212,7 +213,7 @@ func TestFileProvider_SessionRegenerate(t *testing.T) {
t.Error() t.Error()
} }
exists, err = fp.SessionExist(sidNew) exists, err = fp.SessionExist(context.Background(), sidNew)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -228,14 +229,14 @@ func TestFileProvider_SessionDestroy(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(sid) _, err := fp.SessionRead(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
exists, err := fp.SessionExist(sid) exists, err := fp.SessionExist(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -243,12 +244,12 @@ func TestFileProvider_SessionDestroy(t *testing.T) {
t.Error() t.Error()
} }
err = fp.SessionDestroy(sid) err = fp.SessionDestroy(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
exists, err = fp.SessionExist(sid) exists, err = fp.SessionExist(context.Background(), sid)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -264,12 +265,12 @@ func TestFileProvider_SessionGC(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(1, sessionPath) _ = fp.SessionInit(context.Background(), 1, sessionPath)
sessionCount := 412 sessionCount := 412
for i := 1; i <= sessionCount; i++ { for i := 1; i <= sessionCount; i++ {
_, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) _, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -277,8 +278,8 @@ func TestFileProvider_SessionGC(t *testing.T) {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
fp.SessionGC() fp.SessionGC(nil)
if fp.SessionAll() != 0 { if fp.SessionAll(nil) != 0 {
t.Error() t.Error()
} }
} }
@ -290,12 +291,12 @@ func TestFileSessionStore_Set(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100 sessionCount := 100
s, _ := fp.SessionRead(sid) s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ { for i := 1; i <= sessionCount; i++ {
err := s.Set(i, i) err := s.Set(nil, i, i)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -309,14 +310,14 @@ func TestFileSessionStore_Get(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100 sessionCount := 100
s, _ := fp.SessionRead(sid) s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ { for i := 1; i <= sessionCount; i++ {
_ = s.Set(i, i) _ = s.Set(nil, i, i)
v := s.Get(i) v := s.Get(nil, i)
if v.(int) != i { if v.(int) != i {
t.Error() t.Error()
} }
@ -330,18 +331,18 @@ func TestFileSessionStore_Delete(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
s, _ := fp.SessionRead(sid) s, _ := fp.SessionRead(context.Background(), sid)
s.Set("1", 1) s.Set(nil, "1", 1)
if s.Get("1") == nil { if s.Get(nil, "1") == nil {
t.Error() t.Error()
} }
s.Delete("1") s.Delete(nil, "1")
if s.Get("1") != nil { if s.Get(nil, "1") != nil {
t.Error() t.Error()
} }
} }
@ -353,18 +354,18 @@ func TestFileSessionStore_Flush(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100 sessionCount := 100
s, _ := fp.SessionRead(sid) s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ { for i := 1; i <= sessionCount; i++ {
_ = s.Set(i, i) _ = s.Set(nil, i, i)
} }
_ = s.Flush() _ = s.Flush(nil)
for i := 1; i <= sessionCount; i++ { for i := 1; i <= sessionCount; i++ {
if s.Get(i) != nil { if s.Get(nil, i) != nil {
t.Error() t.Error()
} }
} }
@ -377,16 +378,16 @@ func TestFileSessionStore_SessionID(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 85 sessionCount := 85
for i := 1; i <= sessionCount; i++ { for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { if s.SessionID(nil) != fmt.Sprintf("%s_%d", sid, i) {
t.Error(err) t.Error(err)
} }
} }
@ -399,27 +400,27 @@ func TestFileSessionStore_SessionRelease(t *testing.T) {
defer os.RemoveAll(sessionPath) defer os.RemoveAll(sessionPath)
fp := &FileProvider{} fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
filepder.savePath = sessionPath filepder.savePath = sessionPath
sessionCount := 85 sessionCount := 85
for i := 1; i <= sessionCount; i++ { for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
s.Set(i, i) s.Set(nil, i, i)
s.SessionRelease(nil) s.SessionRelease(nil, nil)
} }
for i := 1; i <= sessionCount; i++ { for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if s.Get(i).(int) != i { if s.Get(nil, i).(int) != i {
t.Error() t.Error()
} }
} }

View File

@ -16,6 +16,7 @@ package session
import ( import (
"container/list" "container/list"
"context"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@ -33,7 +34,7 @@ type MemSessionStore struct {
} }
// Set value to memory session // Set value to memory session
func (st *MemSessionStore) Set(key, value interface{}) error { func (st *MemSessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
st.value[key] = value st.value[key] = value
@ -41,7 +42,7 @@ func (st *MemSessionStore) Set(key, value interface{}) error {
} }
// Get value from memory session by key // Get value from memory session by key
func (st *MemSessionStore) Get(key interface{}) interface{} { func (st *MemSessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock() st.lock.RLock()
defer st.lock.RUnlock() defer st.lock.RUnlock()
if v, ok := st.value[key]; ok { if v, ok := st.value[key]; ok {
@ -51,7 +52,7 @@ func (st *MemSessionStore) Get(key interface{}) interface{} {
} }
// Delete in memory session by key // Delete in memory session by key
func (st *MemSessionStore) Delete(key interface{}) error { func (st *MemSessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
delete(st.value, key) delete(st.value, key)
@ -59,7 +60,7 @@ func (st *MemSessionStore) Delete(key interface{}) error {
} }
// Flush clear all values in memory session // Flush clear all values in memory session
func (st *MemSessionStore) Flush() error { func (st *MemSessionStore) Flush(context.Context) error {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
st.value = make(map[interface{}]interface{}) st.value = make(map[interface{}]interface{})
@ -67,12 +68,12 @@ func (st *MemSessionStore) Flush() error {
} }
// SessionID get this id of memory session store // SessionID get this id of memory session store
func (st *MemSessionStore) SessionID() string { func (st *MemSessionStore) SessionID(context.Context) string {
return st.sid return st.sid
} }
// SessionRelease Implement method, no used. // SessionRelease Implement method, no used.
func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { func (st *MemSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
} }
// MemProvider Implement the provider interface // MemProvider Implement the provider interface
@ -85,14 +86,14 @@ type MemProvider struct {
} }
// SessionInit init memory session // SessionInit init memory session
func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
pder.maxlifetime = maxlifetime pder.maxlifetime = maxlifetime
pder.savePath = savePath pder.savePath = savePath
return nil return nil
} }
// SessionRead get memory session store by sid // SessionRead get memory session store by sid
func (pder *MemProvider) SessionRead(sid string) (Store, error) { func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
pder.lock.RLock() pder.lock.RLock()
if element, ok := pder.sessions[sid]; ok { if element, ok := pder.sessions[sid]; ok {
go pder.SessionUpdate(sid) go pder.SessionUpdate(sid)
@ -109,7 +110,7 @@ func (pder *MemProvider) SessionRead(sid string) (Store, error) {
} }
// SessionExist check session store exist in memory session by sid // SessionExist check session store exist in memory session by sid
func (pder *MemProvider) SessionExist(sid string) (bool, error) { func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
pder.lock.RLock() pder.lock.RLock()
defer pder.lock.RUnlock() defer pder.lock.RUnlock()
if _, ok := pder.sessions[sid]; ok { if _, ok := pder.sessions[sid]; ok {
@ -119,7 +120,7 @@ func (pder *MemProvider) SessionExist(sid string) (bool, error) {
} }
// SessionRegenerate generate new sid for session store in memory session // SessionRegenerate generate new sid for session store in memory session
func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
pder.lock.RLock() pder.lock.RLock()
if element, ok := pder.sessions[oldsid]; ok { if element, ok := pder.sessions[oldsid]; ok {
go pder.SessionUpdate(oldsid) go pder.SessionUpdate(oldsid)
@ -141,7 +142,7 @@ func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
} }
// SessionDestroy delete session store in memory session by id // SessionDestroy delete session store in memory session by id
func (pder *MemProvider) SessionDestroy(sid string) error { func (pder *MemProvider) SessionDestroy(ctx context.Context, sid string) error {
pder.lock.Lock() pder.lock.Lock()
defer pder.lock.Unlock() defer pder.lock.Unlock()
if element, ok := pder.sessions[sid]; ok { if element, ok := pder.sessions[sid]; ok {
@ -153,7 +154,7 @@ func (pder *MemProvider) SessionDestroy(sid string) error {
} }
// SessionGC clean expired session stores in memory session // SessionGC clean expired session stores in memory session
func (pder *MemProvider) SessionGC() { func (pder *MemProvider) SessionGC(context.Context) {
pder.lock.RLock() pder.lock.RLock()
for { for {
element := pder.list.Back() element := pder.list.Back()
@ -175,7 +176,7 @@ func (pder *MemProvider) SessionGC() {
} }
// SessionAll get count number of memory session // SessionAll get count number of memory session
func (pder *MemProvider) SessionAll() int { func (pder *MemProvider) SessionAll(context.Context) int {
return pder.list.Len() return pder.list.Len()
} }

View File

@ -36,12 +36,12 @@ func TestMem(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("set error,", err) t.Fatal("set error,", err)
} }
defer sess.SessionRelease(w) defer sess.SessionRelease(nil, w)
err = sess.Set("username", "astaxie") err = sess.Set(nil, "username", "astaxie")
if err != nil { if err != nil {
t.Fatal("set error,", err) t.Fatal("set error,", err)
} }
if username := sess.Get("username"); username != "astaxie" { if username := sess.Get(nil, "username"); username != "astaxie" {
t.Fatal("get username error") t.Fatal("get username error")
} }
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {

View File

@ -28,6 +28,7 @@
package session package session
import ( import (
"context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"errors" "errors"
@ -43,24 +44,24 @@ import (
// Store contains all data for one session process with specific id. // Store contains all data for one session process with specific id.
type Store interface { type Store interface {
Set(key, value interface{}) error //set session value Set(ctx context.Context, key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value Get(ctx context.Context, key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value Delete(ctx context.Context, key interface{}) error //delete session value
SessionID() string //back current sessionID SessionID(ctx context.Context) string //back current sessionID
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data SessionRelease(ctx context.Context, w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data Flush(ctx context.Context) error //delete all data
} }
// Provider contains global session methods and saved SessionStores. // Provider contains global session methods and saved SessionStores.
// it can operate a SessionStore by its id. // it can operate a SessionStore by its id.
type Provider interface { type Provider interface {
SessionInit(gclifetime int64, config string) error SessionInit(ctx context.Context, gclifetime int64, config string) error
SessionRead(sid string) (Store, error) SessionRead(ctx context.Context, sid string) (Store, error)
SessionExist(sid string) (bool, error) SessionExist(ctx context.Context, sid string) (bool, error)
SessionRegenerate(oldsid, sid string) (Store, error) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error)
SessionDestroy(sid string) error SessionDestroy(ctx context.Context, sid string) error
SessionAll() int //get all active session SessionAll(ctx context.Context) int //get all active session
SessionGC() SessionGC(ctx context.Context)
} }
var provides = make(map[string]Provider) var provides = make(map[string]Provider)
@ -148,7 +149,7 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) {
} }
} }
err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) err := provider.SessionInit(nil, cf.Maxlifetime, cf.ProviderConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -212,12 +213,12 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
} }
if sid != "" { if sid != "" {
exists, err := manager.provider.SessionExist(sid) exists, err := manager.provider.SessionExist(nil, sid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if exists { if exists {
return manager.provider.SessionRead(sid) return manager.provider.SessionRead(nil, sid)
} }
} }
@ -227,7 +228,7 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
return nil, errs return nil, errs
} }
session, err = manager.provider.SessionRead(sid) session, err = manager.provider.SessionRead(nil, sid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -269,7 +270,7 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
} }
sid, _ := url.QueryUnescape(cookie.Value) sid, _ := url.QueryUnescape(cookie.Value)
manager.provider.SessionDestroy(sid) manager.provider.SessionDestroy(nil, sid)
if manager.config.EnableSetCookie { if manager.config.EnableSetCookie {
expiration := time.Now() expiration := time.Now()
cookie = &http.Cookie{Name: manager.config.CookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
@ -285,14 +286,14 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
// GetSessionStore Get SessionStore by its id. // GetSessionStore Get SessionStore by its id.
func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) {
sessions, err = manager.provider.SessionRead(sid) sessions, err = manager.provider.SessionRead(nil, sid)
return return
} }
// GC Start session gc process. // GC Start session gc process.
// it can do gc in times after gc lifetime. // it can do gc in times after gc lifetime.
func (manager *Manager) GC() { func (manager *Manager) GC() {
manager.provider.SessionGC() manager.provider.SessionGC(nil)
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
} }
@ -305,7 +306,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque
cookie, err := r.Cookie(manager.config.CookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
//delete old cookie //delete old cookie
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(nil, sid)
cookie = &http.Cookie{Name: manager.config.CookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
@ -315,7 +316,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque
} }
} else { } else {
oldsid, _ := url.QueryUnescape(cookie.Value) oldsid, _ := url.QueryUnescape(cookie.Value)
session, _ = manager.provider.SessionRegenerate(oldsid, sid) session, _ = manager.provider.SessionRegenerate(nil, oldsid, sid)
cookie.Value = url.QueryEscape(sid) cookie.Value = url.QueryEscape(sid)
cookie.HttpOnly = true cookie.HttpOnly = true
cookie.Path = "/" cookie.Path = "/"
@ -339,7 +340,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque
// GetActiveSession Get all active sessions count number. // GetActiveSession Get all active sessions count number.
func (manager *Manager) GetActiveSession() int { func (manager *Manager) GetActiveSession() int {
return manager.provider.SessionAll() return manager.provider.SessionAll(nil)
} }
// SetSecure Set cookie with https. // SetSecure Set cookie with https.

View File

@ -1,6 +1,7 @@
package ssdb package ssdb
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"strconv" "strconv"
@ -31,7 +32,7 @@ func (p *Provider) connectInit() error {
} }
// SessionInit init the ssdb with the config // SessionInit init the ssdb with the config
func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error {
p.maxLifetime = maxLifetime p.maxLifetime = maxLifetime
address := strings.Split(savePath, ":") address := strings.Split(savePath, ":")
p.host = address[0] p.host = address[0]
@ -44,7 +45,7 @@ func (p *Provider) SessionInit(maxLifetime int64, savePath string) error {
} }
// SessionRead return a ssdb client session Store // SessionRead return a ssdb client session Store
func (p *Provider) SessionRead(sid string) (session.Store, error) { func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if p.client == nil { if p.client == nil {
if err := p.connectInit(); err != nil { if err := p.connectInit(); err != nil {
return nil, err return nil, err
@ -68,7 +69,7 @@ func (p *Provider) SessionRead(sid string) (session.Store, error) {
} }
// SessionExist judged whether sid is exist in session // SessionExist judged whether sid is exist in session
func (p *Provider) SessionExist(sid string) (bool, error) { func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
if p.client == nil { if p.client == nil {
if err := p.connectInit(); err != nil { if err := p.connectInit(); err != nil {
return false, err return false, err
@ -85,7 +86,7 @@ func (p *Provider) SessionExist(sid string) (bool, error) {
} }
// SessionRegenerate regenerate session with new sid and delete oldsid // SessionRegenerate regenerate session with new sid and delete oldsid
func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { func (p *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
//conn.Do("setx", key, v, ttl) //conn.Do("setx", key, v, ttl)
if p.client == nil { if p.client == nil {
if err := p.connectInit(); err != nil { if err := p.connectInit(); err != nil {
@ -118,7 +119,7 @@ func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
} }
// SessionDestroy destroy the sid // SessionDestroy destroy the sid
func (p *Provider) SessionDestroy(sid string) error { func (p *Provider) SessionDestroy(ctx context.Context, sid string) error {
if p.client == nil { if p.client == nil {
if err := p.connectInit(); err != nil { if err := p.connectInit(); err != nil {
return err return err
@ -129,11 +130,11 @@ func (p *Provider) SessionDestroy(sid string) error {
} }
// SessionGC not implemented // SessionGC not implemented
func (p *Provider) SessionGC() { func (p *Provider) SessionGC(context.Context) {
} }
// SessionAll not implemented // SessionAll not implemented
func (p *Provider) SessionAll() int { func (p *Provider) SessionAll(context.Context) int {
return 0 return 0
} }
@ -147,7 +148,7 @@ type SessionStore struct {
} }
// Set the key and value // Set the key and value
func (s *SessionStore) Set(key, value interface{}) error { func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
s.values[key] = value s.values[key] = value
@ -155,7 +156,7 @@ func (s *SessionStore) Set(key, value interface{}) error {
} }
// Get return the value by the key // Get return the value by the key
func (s *SessionStore) Get(key interface{}) interface{} { func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
if value, ok := s.values[key]; ok { if value, ok := s.values[key]; ok {
@ -165,7 +166,7 @@ func (s *SessionStore) Get(key interface{}) interface{} {
} }
// Delete the key in session store // Delete the key in session store
func (s *SessionStore) Delete(key interface{}) error { func (s *SessionStore) Delete(ctx context.Context, key interface{}) error {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
delete(s.values, key) delete(s.values, key)
@ -173,7 +174,7 @@ func (s *SessionStore) Delete(key interface{}) error {
} }
// Flush delete all keys and values // Flush delete all keys and values
func (s *SessionStore) Flush() error { func (s *SessionStore) Flush(context.Context) error {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
s.values = make(map[interface{}]interface{}) s.values = make(map[interface{}]interface{})
@ -181,12 +182,12 @@ func (s *SessionStore) Flush() error {
} }
// SessionID return the sessionID // SessionID return the sessionID
func (s *SessionStore) SessionID() string { func (s *SessionStore) SessionID(context.Context) string {
return s.sid return s.sid
} }
// SessionRelease Store the keyvalues into ssdb // SessionRelease Store the keyvalues into ssdb
func (s *SessionStore) SessionRelease(w http.ResponseWriter) { func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(s.values) b, err := session.EncodeGob(s.values)
if err != nil { if err != nil {
return return

View File

@ -361,7 +361,7 @@ func (input *BeegoInput) Cookie(key string) string {
// Session returns current session item value by a given key. // Session returns current session item value by a given key.
// if non-existed, return nil. // if non-existed, return nil.
func (input *BeegoInput) Session(key interface{}) interface{} { func (input *BeegoInput) Session(key interface{}) interface{} {
return input.CruSession.Get(key) return input.CruSession.Get(nil, key)
} }
// CopyBody returns the raw request body data as bytes. // CopyBody returns the raw request body data as bytes.

View File

@ -404,5 +404,5 @@ func stringsToJSON(str string) string {
// Session sets session item value with given key. // Session sets session item value with given key.
func (output *BeegoOutput) Session(name interface{}, value interface{}) { func (output *BeegoOutput) Session(name interface{}, value interface{}) {
output.Context.Input.CruSession.Set(name, value) output.Context.Input.CruSession.Set(nil, name, value)
} }

View File

@ -622,7 +622,7 @@ func (c *Controller) SetSession(name interface{}, value interface{}) {
if c.CruSession == nil { if c.CruSession == nil {
c.StartSession() c.StartSession()
} }
c.CruSession.Set(name, value) c.CruSession.Set(nil, name, value)
} }
// GetSession gets value from session. // GetSession gets value from session.
@ -630,7 +630,7 @@ func (c *Controller) GetSession(name interface{}) interface{} {
if c.CruSession == nil { if c.CruSession == nil {
c.StartSession() c.StartSession()
} }
return c.CruSession.Get(name) return c.CruSession.Get(nil, name)
} }
// DelSession removes value from session. // DelSession removes value from session.
@ -638,14 +638,14 @@ func (c *Controller) DelSession(name interface{}) {
if c.CruSession == nil { if c.CruSession == nil {
c.StartSession() c.StartSession()
} }
c.CruSession.Delete(name) c.CruSession.Delete(nil, name)
} }
// SessionRegenerateID regenerates session id for this session. // SessionRegenerateID regenerates session id for this session.
// the session data have no changes. // the session data have no changes.
func (c *Controller) SessionRegenerateID() { func (c *Controller) SessionRegenerateID() {
if c.CruSession != nil { if c.CruSession != nil {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter) c.CruSession.SessionRelease(nil, c.Ctx.ResponseWriter)
} }
c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession c.Ctx.Input.CruSession = c.CruSession
@ -653,7 +653,7 @@ func (c *Controller) SessionRegenerateID() {
// DestroySession cleans session data and session cookie. // DestroySession cleans session data and session cookie.
func (c *Controller) DestroySession() { func (c *Controller) DestroySession() {
c.Ctx.Input.CruSession.Flush() c.Ctx.Input.CruSession.Flush(nil)
c.Ctx.Input.CruSession = nil c.Ctx.Input.CruSession = nil
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
} }

View File

@ -721,7 +721,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) {
} }
defer func() { defer func() {
if ctx.Input.CruSession != nil { if ctx.Input.CruSession != nil {
ctx.Input.CruSession.SessionRelease(rw) ctx.Input.CruSession.SessionRelease(nil, rw)
} }
}() }()
} }