1
0
mirror of https://github.com/astaxie/beego.git synced 2024-12-22 17:00:50 +00:00

Add ctx to session API

This commit is contained in:
Ming Deng 2020-08-30 15:39:07 +00:00
parent 0019e0fc1b
commit 670064686e
23 changed files with 302 additions and 288 deletions

View File

@ -129,7 +129,7 @@ func TestCache_Scan(t *testing.T) {
t.Error("init err")
}
// insert all
for i := 0; i < 10000; i++ {
for i := 0; i < 100; i++ {
if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil {
t.Error("set Error", err)
}
@ -141,7 +141,7 @@ func TestCache_Scan(t *testing.T) {
t.Error("scan Error", err)
}
assert.Equal(t, 10000, len(keys), "scan all error")
assert.Equal(t, 100, len(keys), "scan all error")
// clear all
if err = bm.ClearAll(); err != nil {

View File

@ -33,6 +33,7 @@
package couchbase
import (
"context"
"net/http"
"strings"
"sync"
@ -63,7 +64,7 @@ type Provider struct {
}
// Set value to couchabse session
func (cs *SessionStore) Set(key, value interface{}) error {
func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values[key] = value
@ -71,7 +72,7 @@ func (cs *SessionStore) Set(key, value interface{}) error {
}
// Get value from couchabse session
func (cs *SessionStore) Get(key interface{}) interface{} {
func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
cs.lock.RLock()
defer cs.lock.RUnlock()
if v, ok := cs.values[key]; ok {
@ -81,7 +82,7 @@ func (cs *SessionStore) Get(key interface{}) interface{} {
}
// Delete value in couchbase session by given key
func (cs *SessionStore) Delete(key interface{}) error {
func (cs *SessionStore) Delete(ctx context.Context, key interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
delete(cs.values, key)
@ -89,7 +90,7 @@ func (cs *SessionStore) Delete(key interface{}) error {
}
// Flush Clean all values in couchbase session
func (cs *SessionStore) Flush() error {
func (cs *SessionStore) Flush(context.Context) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values = make(map[interface{}]interface{})
@ -97,12 +98,12 @@ func (cs *SessionStore) Flush() error {
}
// SessionID Get couchbase session store id
func (cs *SessionStore) SessionID() string {
func (cs *SessionStore) SessionID(context.Context) string {
return cs.sid
}
// SessionRelease Write couchbase session with Gob string
func (cs *SessionStore) SessionRelease(w http.ResponseWriter) {
func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer cs.b.Close()
bo, err := session.EncodeGob(cs.values)
@ -135,7 +136,7 @@ func (cp *Provider) getBucket() *couchbase.Bucket {
// SessionInit init couchbase session
// savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket
func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error {
func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
cp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
@ -152,7 +153,7 @@ func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error {
}
// SessionRead read couchbase session by sid
func (cp *Provider) SessionRead(sid string) (session.Store, error) {
func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
cp.b = cp.getBucket()
var (
@ -179,7 +180,7 @@ func (cp *Provider) SessionRead(sid string) (session.Store, error) {
// SessionExist Check couchbase session exist.
// it checkes sid exist or not.
func (cp *Provider) SessionExist(sid string) (bool, error) {
func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
cp.b = cp.getBucket()
defer cp.b.Close()
@ -192,7 +193,7 @@ func (cp *Provider) SessionExist(sid string) (bool, error) {
}
// SessionRegenerate remove oldsid and use sid to generate new session
func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
cp.b = cp.getBucket()
var doc []byte
@ -225,7 +226,7 @@ func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
}
// SessionDestroy Remove bucket in this couchbase
func (cp *Provider) SessionDestroy(sid string) error {
func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error {
cp.b = cp.getBucket()
defer cp.b.Close()
@ -234,11 +235,11 @@ func (cp *Provider) SessionDestroy(sid string) error {
}
// SessionGC Recycle
func (cp *Provider) SessionGC() {
func (cp *Provider) SessionGC(context.Context) {
}
// SessionAll return all active session
func (cp *Provider) SessionAll() int {
func (cp *Provider) SessionAll(context.Context) int {
return 0
}

View File

@ -2,6 +2,7 @@
package ledis
import (
"context"
"net/http"
"strconv"
"strings"
@ -27,7 +28,7 @@ type SessionStore struct {
}
// Set value in ledis session
func (ls *SessionStore) Set(key, value interface{}) error {
func (ls *SessionStore) Set(ctx context.Context, key, value interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values[key] = value
@ -35,7 +36,7 @@ func (ls *SessionStore) Set(key, value interface{}) error {
}
// Get value in ledis session
func (ls *SessionStore) Get(key interface{}) interface{} {
func (ls *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
ls.lock.RLock()
defer ls.lock.RUnlock()
if v, ok := ls.values[key]; ok {
@ -45,7 +46,7 @@ func (ls *SessionStore) Get(key interface{}) interface{} {
}
// Delete value in ledis session
func (ls *SessionStore) Delete(key interface{}) error {
func (ls *SessionStore) Delete(ctx context.Context, key interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
delete(ls.values, key)
@ -53,7 +54,7 @@ func (ls *SessionStore) Delete(key interface{}) error {
}
// Flush clear all values in ledis session
func (ls *SessionStore) Flush() error {
func (ls *SessionStore) Flush(context.Context) error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values = make(map[interface{}]interface{})
@ -61,12 +62,12 @@ func (ls *SessionStore) Flush() error {
}
// SessionID get ledis session id
func (ls *SessionStore) SessionID() string {
func (ls *SessionStore) SessionID(context.Context) string {
return ls.sid
}
// SessionRelease save session values to ledis
func (ls *SessionStore) SessionRelease(w http.ResponseWriter) {
func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(ls.values)
if err != nil {
return
@ -85,7 +86,7 @@ type Provider struct {
// SessionInit init ledis session
// savepath like ledis server saveDataPath,pool size
// e.g. 127.0.0.1:6379,100,astaxie
func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error {
func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
var err error
lp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
@ -111,7 +112,7 @@ func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error {
}
// SessionRead read ledis session by sid
func (lp *Provider) SessionRead(sid string) (session.Store, error) {
func (lp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var (
kv map[interface{}]interface{}
err error
@ -132,13 +133,13 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) {
}
// SessionExist check ledis session exist by sid
func (lp *Provider) SessionExist(sid string) (bool, error) {
func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
count, _ := c.Exists([]byte(sid))
return count != 0, nil
}
// SessionRegenerate generate new sid for ledis session
func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
count, _ := c.Exists([]byte(sid))
if count == 0 {
// oldsid doesn't exists, set the new sid directly
@ -151,21 +152,21 @@ func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
c.Set([]byte(sid), data)
c.Expire([]byte(sid), lp.maxlifetime)
}
return lp.SessionRead(sid)
return lp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete ledis session by id
func (lp *Provider) SessionDestroy(sid string) error {
func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c.Del([]byte(sid))
return nil
}
// SessionGC Impelment method, no used.
func (lp *Provider) SessionGC() {
func (lp *Provider) SessionGC(context.Context) {
}
// SessionAll return all active session
func (lp *Provider) SessionAll() int {
func (lp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {

View File

@ -33,6 +33,7 @@
package memcache
import (
"context"
"net/http"
"strings"
"sync"
@ -54,7 +55,7 @@ type SessionStore struct {
}
// Set value in memcache session
func (rs *SessionStore) Set(key, value interface{}) error {
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
@ -62,7 +63,7 @@ func (rs *SessionStore) Set(key, value interface{}) error {
}
// Get value in memcache session
func (rs *SessionStore) Get(key interface{}) interface{} {
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
@ -72,7 +73,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} {
}
// Delete value in memcache session
func (rs *SessionStore) Delete(key interface{}) error {
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
@ -80,7 +81,7 @@ func (rs *SessionStore) Delete(key interface{}) error {
}
// Flush clear all values in memcache session
func (rs *SessionStore) Flush() error {
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
@ -88,12 +89,12 @@ func (rs *SessionStore) Flush() error {
}
// SessionID get memcache session id
func (rs *SessionStore) SessionID() string {
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to memcache
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
@ -113,7 +114,7 @@ type MemProvider struct {
// SessionInit init memcache session
// savepath like
// e.g. 127.0.0.1:9090
func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
func (rp *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
rp.conninfo = strings.Split(savePath, ";")
client = memcache.New(rp.conninfo...)
@ -121,7 +122,7 @@ func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
}
// SessionRead read memcache session by sid
func (rp *MemProvider) SessionRead(sid string) (session.Store, error) {
func (rp *MemProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
@ -149,7 +150,7 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) {
}
// SessionExist check memcache session exist by sid
func (rp *MemProvider) SessionExist(sid string) (bool, error) {
func (rp *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return false, err
@ -162,7 +163,7 @@ func (rp *MemProvider) SessionExist(sid string) (bool, error) {
}
// SessionRegenerate generate new sid for memcache session
func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
@ -201,7 +202,7 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, err
}
// SessionDestroy delete memcache session by id
func (rp *MemProvider) SessionDestroy(sid string) error {
func (rp *MemProvider) SessionDestroy(ctx context.Context, sid string) error {
if client == nil {
if err := rp.connectInit(); err != nil {
return err
@ -217,11 +218,11 @@ func (rp *MemProvider) connectInit() error {
}
// SessionGC Impelment method, no used.
func (rp *MemProvider) SessionGC() {
func (rp *MemProvider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *MemProvider) SessionAll() int {
func (rp *MemProvider) SessionAll(context.Context) int {
return 0
}

View File

@ -41,6 +41,7 @@
package mysql
import (
"context"
"database/sql"
"net/http"
"sync"
@ -67,7 +68,7 @@ type SessionStore struct {
// Set value in mysql session.
// it is temp value in map.
func (st *SessionStore) Set(key, value interface{}) error {
func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
@ -75,7 +76,7 @@ func (st *SessionStore) Set(key, value interface{}) error {
}
// Get value from mysql session
func (st *SessionStore) Get(key interface{}) interface{} {
func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
@ -85,7 +86,7 @@ func (st *SessionStore) Get(key interface{}) interface{} {
}
// Delete value in mysql session
func (st *SessionStore) Delete(key interface{}) error {
func (st *SessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
@ -93,7 +94,7 @@ func (st *SessionStore) Delete(key interface{}) error {
}
// Flush clear all values in mysql session
func (st *SessionStore) Flush() error {
func (st *SessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
@ -101,13 +102,13 @@ func (st *SessionStore) Flush() error {
}
// SessionID get session id of this mysql session store
func (st *SessionStore) SessionID() string {
func (st *SessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease save mysql session values to database.
// must call this method to save values to database.
func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
@ -134,14 +135,14 @@ func (mp *Provider) connectInit() *sql.DB {
// SessionInit init mysql session.
// savepath is the connection string of mysql.
func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error {
func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
// SessionRead get mysql session by sid
func (mp *Provider) SessionRead(sid string) (session.Store, error) {
func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
@ -164,7 +165,7 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) {
}
// SessionExist check mysql session exist
func (mp *Provider) SessionExist(sid string) (bool, error) {
func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
@ -180,7 +181,7 @@ func (mp *Provider) SessionExist(sid string) (bool, error) {
}
// SessionRegenerate generate new sid for mysql session
func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid)
var sessiondata []byte
@ -203,7 +204,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
}
// SessionDestroy delete mysql session by sid
func (mp *Provider) SessionDestroy(sid string) error {
func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM "+TableName+" where session_key=?", sid)
c.Close()
@ -211,14 +212,14 @@ func (mp *Provider) SessionDestroy(sid string) error {
}
// SessionGC delete expired values in mysql session
func (mp *Provider) SessionGC() {
func (mp *Provider) SessionGC(context.Context) {
c := mp.connectInit()
c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
c.Close()
}
// SessionAll count values in mysql session
func (mp *Provider) SessionAll() int {
func (mp *Provider) SessionAll(context.Context) int {
c := mp.connectInit()
defer c.Close()
var total int

View File

@ -51,6 +51,7 @@
package postgres
import (
"context"
"database/sql"
"net/http"
"sync"
@ -73,7 +74,7 @@ type SessionStore struct {
// Set value in postgresql session.
// it is temp value in map.
func (st *SessionStore) Set(key, value interface{}) error {
func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
@ -81,7 +82,7 @@ func (st *SessionStore) Set(key, value interface{}) error {
}
// Get value from postgresql session
func (st *SessionStore) Get(key interface{}) interface{} {
func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
@ -91,7 +92,7 @@ func (st *SessionStore) Get(key interface{}) interface{} {
}
// Delete value in postgresql session
func (st *SessionStore) Delete(key interface{}) error {
func (st *SessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
@ -99,7 +100,7 @@ func (st *SessionStore) Delete(key interface{}) error {
}
// Flush clear all values in postgresql session
func (st *SessionStore) Flush() error {
func (st *SessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
@ -107,13 +108,13 @@ func (st *SessionStore) Flush() error {
}
// SessionID get session id of this postgresql session store
func (st *SessionStore) SessionID() string {
func (st *SessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease save postgresql session values to database.
// must call this method to save values to database.
func (st *SessionStore) SessionRelease(w http.ResponseWriter) {
func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
@ -141,14 +142,14 @@ func (mp *Provider) connectInit() *sql.DB {
// SessionInit init postgresql session.
// savepath is the connection string of postgresql.
func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error {
func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
// SessionRead get postgresql session by sid
func (mp *Provider) SessionRead(sid string) (session.Store, error) {
func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
@ -178,7 +179,7 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) {
}
// SessionExist check postgresql session exist
func (mp *Provider) SessionExist(sid string) (bool, error) {
func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
@ -194,7 +195,7 @@ func (mp *Provider) SessionExist(sid string) (bool, error) {
}
// SessionRegenerate generate new sid for postgresql session
func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
var sessiondata []byte
@ -218,7 +219,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
}
// SessionDestroy delete postgresql session by sid
func (mp *Provider) SessionDestroy(sid string) error {
func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=$1", sid)
c.Close()
@ -226,14 +227,14 @@ func (mp *Provider) SessionDestroy(sid string) error {
}
// SessionGC delete expired values in postgresql session
func (mp *Provider) SessionGC() {
func (mp *Provider) SessionGC(context.Context) {
c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close()
}
// SessionAll count values in postgresql session
func (mp *Provider) SessionAll() int {
func (mp *Provider) SessionAll(context.Context) int {
c := mp.connectInit()
defer c.Close()
var total int

View File

@ -33,6 +33,7 @@
package redis
import (
"context"
"net/http"
"strconv"
"strings"
@ -59,7 +60,7 @@ type SessionStore struct {
}
// Set value in redis session
func (rs *SessionStore) Set(key, value interface{}) error {
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
@ -67,7 +68,7 @@ func (rs *SessionStore) Set(key, value interface{}) error {
}
// Get value in redis session
func (rs *SessionStore) Get(key interface{}) interface{} {
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
@ -77,7 +78,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} {
}
// Delete value in redis session
func (rs *SessionStore) Delete(key interface{}) error {
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
@ -85,7 +86,7 @@ func (rs *SessionStore) Delete(key interface{}) error {
}
// Flush clear all values in redis session
func (rs *SessionStore) Flush() error {
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
@ -93,12 +94,12 @@ func (rs *SessionStore) Flush() error {
}
// SessionID get redis session id
func (rs *SessionStore) SessionID() string {
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
@ -123,7 +124,7 @@ type Provider struct {
// SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
// e.g. 127.0.0.1:6379,100,astaxie,0,30
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
@ -185,7 +186,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
}
// SessionRead read redis session by sid
func (rp *Provider) SessionRead(sid string) (session.Store, error) {
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
@ -205,7 +206,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) {
}
// SessionExist check redis session exist by sid
func (rp *Provider) SessionExist(sid string) (bool, error) {
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
@ -215,7 +216,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) {
}
// SessionRegenerate generate new sid for redis session
func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, _ := c.Exists(oldsid).Result(); existed == 0 {
// oldsid doesn't exists, set the new sid directly
@ -226,11 +227,11 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(sid)
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(sid string) error {
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
@ -238,11 +239,11 @@ func (rp *Provider) SessionDestroy(sid string) error {
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC() {
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll() int {
func (rp *Provider) SessionAll(context.Context) int {
return 0
}

View File

@ -40,57 +40,57 @@ func TestRedis(t *testing.T) {
if err != nil {
t.Fatal("session start failed:", err)
}
defer sess.SessionRelease(w)
defer sess.SessionRelease(nil, w)
// SET AND GET
err = sess.Set("username", "astaxie")
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set username failed:", err)
}
username := sess.Get("username")
username := sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
// DELETE
err = sess.Delete("username")
err = sess.Delete(nil, "username")
if err != nil {
t.Fatal("delete username failed:", err)
}
username = sess.Get("username")
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("delete username failed")
}
// FLUSH
err = sess.Set("username", "astaxie")
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set failed:", err)
}
err = sess.Set("password", "1qaz2wsx")
err = sess.Set(nil, "password", "1qaz2wsx")
if err != nil {
t.Fatal("set failed:", err)
}
username = sess.Get("username")
username = sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
password := sess.Get("password")
password := sess.Get(nil, "password")
if password != "1qaz2wsx" {
t.Fatal("get password failed")
}
err = sess.Flush()
err = sess.Flush(nil)
if err != nil {
t.Fatal("flush failed:", err)
}
username = sess.Get("username")
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("flush failed")
}
password = sess.Get("password")
password = sess.Get(nil, "password")
if password != nil {
t.Fatal("flush failed")
}
sess.SessionRelease(w)
sess.SessionRelease(nil, w)
}

View File

@ -33,6 +33,7 @@
package redis_cluster
import (
"context"
"net/http"
"strconv"
"strings"
@ -58,7 +59,7 @@ type SessionStore struct {
}
// Set value in redis_cluster session
func (rs *SessionStore) Set(key, value interface{}) error {
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
@ -66,7 +67,7 @@ func (rs *SessionStore) Set(key, value interface{}) error {
}
// Get value in redis_cluster session
func (rs *SessionStore) Get(key interface{}) interface{} {
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
@ -76,7 +77,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} {
}
// Delete value in redis_cluster session
func (rs *SessionStore) Delete(key interface{}) error {
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
@ -84,7 +85,7 @@ func (rs *SessionStore) Delete(key interface{}) error {
}
// Flush clear all values in redis_cluster session
func (rs *SessionStore) Flush() error {
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
@ -92,12 +93,12 @@ func (rs *SessionStore) Flush() error {
}
// SessionID get redis_cluster session id
func (rs *SessionStore) SessionID() string {
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis_cluster
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
@ -122,7 +123,7 @@ type Provider struct {
// SessionInit init redis_cluster session
// savepath like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
@ -182,7 +183,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
}
// SessionRead read redis_cluster session by sid
func (rp *Provider) SessionRead(sid string) (session.Store, error) {
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != rediss.Nil {
@ -201,7 +202,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) {
}
// SessionExist check redis_cluster session exist by sid
func (rp *Provider) SessionExist(sid string) (bool, error) {
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
@ -210,7 +211,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) {
}
// SessionRegenerate generate new sid for redis_cluster session
func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
@ -222,22 +223,22 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(sid)
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(sid string) error {
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC() {
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll() int {
func (rp *Provider) SessionAll(context.Context) int {
return 0
}

View File

@ -33,6 +33,7 @@
package redis_sentinel
import (
"context"
"net/http"
"strconv"
"strings"
@ -58,7 +59,7 @@ type SessionStore struct {
}
// Set value in redis_sentinel session
func (rs *SessionStore) Set(key, value interface{}) error {
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
@ -66,7 +67,7 @@ func (rs *SessionStore) Set(key, value interface{}) error {
}
// Get value in redis_sentinel session
func (rs *SessionStore) Get(key interface{}) interface{} {
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
@ -76,7 +77,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} {
}
// Delete value in redis_sentinel session
func (rs *SessionStore) Delete(key interface{}) error {
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
@ -84,7 +85,7 @@ func (rs *SessionStore) Delete(key interface{}) error {
}
// Flush clear all values in redis_sentinel session
func (rs *SessionStore) Flush() error {
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
@ -92,12 +93,12 @@ func (rs *SessionStore) Flush() error {
}
// SessionID get redis_sentinel session id
func (rs *SessionStore) SessionID() string {
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis_sentinel
func (rs *SessionStore) SessionRelease(w http.ResponseWriter) {
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
@ -123,7 +124,7 @@ type Provider struct {
// SessionInit init redis_sentinel session
// savepath like redis sentinel addr,pool size,password,dbnum,masterName
// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
@ -195,7 +196,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
}
// SessionRead read redis_sentinel session by sid
func (rp *Provider) SessionRead(sid string) (session.Store, error) {
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != redis.Nil {
@ -214,7 +215,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) {
}
// SessionExist check redis_sentinel session exist by sid
func (rp *Provider) SessionExist(sid string) (bool, error) {
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
@ -223,7 +224,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) {
}
// SessionRegenerate generate new sid for redis_sentinel session
func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
@ -235,22 +236,22 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(sid)
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(sid string) error {
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC() {
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll() int {
func (rp *Provider) SessionAll(context.Context) int {
return 0
}

View File

@ -33,58 +33,58 @@ func TestRedisSentinel(t *testing.T) {
if err != nil {
t.Fatal("session start failed:", err)
}
defer sess.SessionRelease(w)
defer sess.SessionRelease(nil, w)
// SET AND GET
err = sess.Set("username", "astaxie")
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set username failed:", err)
}
username := sess.Get("username")
username := sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
// DELETE
err = sess.Delete("username")
err = sess.Delete(nil, "username")
if err != nil {
t.Fatal("delete username failed:", err)
}
username = sess.Get("username")
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("delete username failed")
}
// FLUSH
err = sess.Set("username", "astaxie")
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set failed:", err)
}
err = sess.Set("password", "1qaz2wsx")
err = sess.Set(nil, "password", "1qaz2wsx")
if err != nil {
t.Fatal("set failed:", err)
}
username = sess.Get("username")
username = sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
password := sess.Get("password")
password := sess.Get(nil, "password")
if password != "1qaz2wsx" {
t.Fatal("get password failed")
}
err = sess.Flush()
err = sess.Flush(nil)
if err != nil {
t.Fatal("flush failed:", err)
}
username = sess.Get("username")
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("flush failed")
}
password = sess.Get("password")
password = sess.Get(nil, "password")
if password != nil {
t.Fatal("flush failed")
}
sess.SessionRelease(w)
sess.SessionRelease(nil, w)
}

View File

@ -15,6 +15,7 @@
package session
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/json"
@ -34,7 +35,7 @@ type CookieSessionStore struct {
// Set value to cookie session.
// the value are encoded as gob with hash block string.
func (st *CookieSessionStore) Set(key, value interface{}) error {
func (st *CookieSessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
@ -42,7 +43,7 @@ func (st *CookieSessionStore) Set(key, value interface{}) error {
}
// Get value from cookie session
func (st *CookieSessionStore) Get(key interface{}) interface{} {
func (st *CookieSessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
@ -52,7 +53,7 @@ func (st *CookieSessionStore) Get(key interface{}) interface{} {
}
// Delete value in cookie session
func (st *CookieSessionStore) Delete(key interface{}) error {
func (st *CookieSessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
@ -60,7 +61,7 @@ func (st *CookieSessionStore) Delete(key interface{}) error {
}
// Flush Clean all values in cookie session
func (st *CookieSessionStore) Flush() error {
func (st *CookieSessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
@ -68,12 +69,12 @@ func (st *CookieSessionStore) Flush() error {
}
// SessionID Return id of this cookie session
func (st *CookieSessionStore) SessionID() string {
func (st *CookieSessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
st.lock.Lock()
encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values)
st.lock.Unlock()
@ -112,7 +113,7 @@ type CookieProvider struct {
// securityName - recognized name in encoded cookie string
// cookieName - cookie name
// maxage - cookie max life time.
func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error {
func (pder *CookieProvider) SessionInit(ctx context.Context, maxlifetime int64, config string) error {
pder.config = &cookieConfig{}
err := json.Unmarshal([]byte(config), pder.config)
if err != nil {
@ -134,7 +135,7 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error
// SessionRead Get SessionStore in cooke.
// decode cooke string to map and put into SessionStore with sid.
func (pder *CookieProvider) SessionRead(sid string) (Store, error) {
func (pder *CookieProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
@ -147,26 +148,26 @@ func (pder *CookieProvider) SessionRead(sid string) (Store, error) {
}
// SessionExist Cookie session is always existed
func (pder *CookieProvider) SessionExist(sid string) (bool, error) {
func (pder *CookieProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
return true, nil
}
// SessionRegenerate Implement method, no used.
func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
func (pder *CookieProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
return nil, nil
}
// SessionDestroy Implement method, no used.
func (pder *CookieProvider) SessionDestroy(sid string) error {
func (pder *CookieProvider) SessionDestroy(ctx context.Context, sid string) error {
return nil
}
// SessionGC Implement method, no used.
func (pder *CookieProvider) SessionGC() {
func (pder *CookieProvider) SessionGC(context.Context) {
}
// SessionAll Implement method, return 0.
func (pder *CookieProvider) SessionAll() int {
func (pder *CookieProvider) SessionAll(context.Context) int {
return 0
}

View File

@ -38,14 +38,14 @@ func TestCookie(t *testing.T) {
if err != nil {
t.Fatal("set error,", err)
}
err = sess.Set("username", "astaxie")
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
if username := sess.Get(nil, "username"); username != "astaxie" {
t.Fatal("get username error")
}
sess.SessionRelease(w)
sess.SessionRelease(nil, w)
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
@ -85,7 +85,7 @@ func TestDestorySessionCookie(t *testing.T) {
if err != nil {
t.Fatal("session start err,", err)
}
if newSession.SessionID() != session.SessionID() {
if newSession.SessionID(nil) != session.SessionID(nil) {
t.Fatal("get cookie session id is not the same again.")
}
@ -99,7 +99,7 @@ func TestDestorySessionCookie(t *testing.T) {
if err != nil {
t.Fatal("session start error")
}
if newSession.SessionID() == session.SessionID() {
if newSession.SessionID(nil) == session.SessionID(nil) {
t.Fatal("after destroy session and reqeust again ,get cookie session id is same.")
}
}

View File

@ -15,6 +15,7 @@
package session
import (
"context"
"errors"
"fmt"
"io/ioutil"
@ -40,7 +41,7 @@ type FileSessionStore struct {
}
// Set value to file session
func (fs *FileSessionStore) Set(key, value interface{}) error {
func (fs *FileSessionStore) Set(ctx context.Context, key, value interface{}) error {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.values[key] = value
@ -48,7 +49,7 @@ func (fs *FileSessionStore) Set(key, value interface{}) error {
}
// Get value from file session
func (fs *FileSessionStore) Get(key interface{}) interface{} {
func (fs *FileSessionStore) Get(ctx context.Context, key interface{}) interface{} {
fs.lock.RLock()
defer fs.lock.RUnlock()
if v, ok := fs.values[key]; ok {
@ -58,7 +59,7 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} {
}
// Delete value in file session by given key
func (fs *FileSessionStore) Delete(key interface{}) error {
func (fs *FileSessionStore) Delete(ctx context.Context, key interface{}) error {
fs.lock.Lock()
defer fs.lock.Unlock()
delete(fs.values, key)
@ -66,7 +67,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error {
}
// Flush Clean all values in file session
func (fs *FileSessionStore) Flush() error {
func (fs *FileSessionStore) Flush(context.Context) error {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.values = make(map[interface{}]interface{})
@ -74,12 +75,12 @@ func (fs *FileSessionStore) Flush() error {
}
// SessionID Get file session store id
func (fs *FileSessionStore) SessionID() string {
func (fs *FileSessionStore) SessionID(context.Context) string {
return fs.sid
}
// SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
b, err := EncodeGob(fs.values)
@ -119,7 +120,7 @@ type FileProvider struct {
// SessionInit Init file session provider.
// savePath sets the session files path.
func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
func (fp *FileProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
fp.maxlifetime = maxlifetime
fp.savePath = savePath
return nil
@ -128,7 +129,7 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error {
// SessionRead Read file session by sid.
// if file is not exist, create it.
// the file path is generated from sid string.
func (fp *FileProvider) SessionRead(sid string) (Store, error) {
func (fp *FileProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
invalidChars := "./"
if strings.ContainsAny(sid, invalidChars) {
return nil, errors.New("the sid shouldn't have following characters: " + invalidChars)
@ -176,7 +177,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
// SessionExist Check file session exist.
// it checks the file named from sid exist or not.
func (fp *FileProvider) SessionExist(sid string) (bool, error) {
func (fp *FileProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
@ -190,7 +191,7 @@ func (fp *FileProvider) SessionExist(sid string) (bool, error) {
}
// SessionDestroy Remove all files in this save path
func (fp *FileProvider) SessionDestroy(sid string) error {
func (fp *FileProvider) SessionDestroy(ctx context.Context, sid string) error {
filepder.lock.Lock()
defer filepder.lock.Unlock()
os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
@ -198,7 +199,7 @@ func (fp *FileProvider) SessionDestroy(sid string) error {
}
// SessionGC Recycle files in save path
func (fp *FileProvider) SessionGC() {
func (fp *FileProvider) SessionGC(context.Context) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
@ -208,7 +209,7 @@ func (fp *FileProvider) SessionGC() {
// SessionAll Get active file session number.
// it walks save path to count files.
func (fp *FileProvider) SessionAll() int {
func (fp *FileProvider) SessionAll(context.Context) int {
a := &activeSession{}
err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error {
return a.visit(path, f, err)
@ -222,7 +223,7 @@ func (fp *FileProvider) SessionAll() int {
// SessionRegenerate Generate new sid for file session.
// it delete old file and create new file named from new sid.
func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()

View File

@ -15,6 +15,7 @@
package session
import (
"context"
"fmt"
"os"
"sync"
@ -37,7 +38,7 @@ func TestFileProvider_SessionInit(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
if fp.maxlifetime != 180 {
t.Error()
}
@ -54,9 +55,9 @@ func TestFileProvider_SessionExist(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
exists, err := fp.SessionExist(sid)
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
@ -64,12 +65,12 @@ func TestFileProvider_SessionExist(t *testing.T) {
t.Error()
}
_, err = fp.SessionRead(sid)
_, err = fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(sid)
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
@ -85,9 +86,9 @@ func TestFileProvider_SessionExist2(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
exists, err := fp.SessionExist(sid)
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
@ -95,7 +96,7 @@ func TestFileProvider_SessionExist2(t *testing.T) {
t.Error()
}
exists, err = fp.SessionExist("")
exists, err = fp.SessionExist(context.Background(), "")
if err == nil {
t.Error()
}
@ -103,7 +104,7 @@ func TestFileProvider_SessionExist2(t *testing.T) {
t.Error()
}
exists, err = fp.SessionExist("1")
exists, err = fp.SessionExist(context.Background(), "1")
if err == nil {
t.Error()
}
@ -119,15 +120,15 @@ func TestFileProvider_SessionRead(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
s, err := fp.SessionRead(sid)
s, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
_ = s.Set("sessionValue", 18975)
v := s.Get("sessionValue")
_ = s.Set(nil, "sessionValue", 18975)
v := s.Get(nil, "sessionValue")
if v.(int) != 18975 {
t.Error()
@ -141,14 +142,14 @@ func TestFileProvider_SessionRead1(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead("")
_, err := fp.SessionRead(context.Background(), "")
if err == nil {
t.Error(err)
}
_, err = fp.SessionRead("1")
_, err = fp.SessionRead(context.Background(), "1")
if err == nil {
t.Error(err)
}
@ -161,18 +162,18 @@ func TestFileProvider_SessionAll(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 546
for i := 1; i <= sessionCount; i++ {
_, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i))
_, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
}
if fp.SessionAll() != sessionCount {
if fp.SessionAll(nil) != sessionCount {
t.Error()
}
}
@ -184,14 +185,14 @@ func TestFileProvider_SessionRegenerate(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(sid)
_, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err := fp.SessionExist(sid)
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
@ -199,12 +200,12 @@ func TestFileProvider_SessionRegenerate(t *testing.T) {
t.Error()
}
_, err = fp.SessionRegenerate(sid, sidNew)
_, err = fp.SessionRegenerate(context.Background(), sid, sidNew)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(sid)
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
@ -212,7 +213,7 @@ func TestFileProvider_SessionRegenerate(t *testing.T) {
t.Error()
}
exists, err = fp.SessionExist(sidNew)
exists, err = fp.SessionExist(context.Background(), sidNew)
if err != nil {
t.Error(err)
}
@ -228,14 +229,14 @@ func TestFileProvider_SessionDestroy(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(sid)
_, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err := fp.SessionExist(sid)
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
@ -243,12 +244,12 @@ func TestFileProvider_SessionDestroy(t *testing.T) {
t.Error()
}
err = fp.SessionDestroy(sid)
err = fp.SessionDestroy(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(sid)
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
@ -264,12 +265,12 @@ func TestFileProvider_SessionGC(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(1, sessionPath)
_ = fp.SessionInit(context.Background(), 1, sessionPath)
sessionCount := 412
for i := 1; i <= sessionCount; i++ {
_, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i))
_, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
@ -277,8 +278,8 @@ func TestFileProvider_SessionGC(t *testing.T) {
time.Sleep(2 * time.Second)
fp.SessionGC()
if fp.SessionAll() != 0 {
fp.SessionGC(nil)
if fp.SessionAll(nil) != 0 {
t.Error()
}
}
@ -290,12 +291,12 @@ func TestFileSessionStore_Set(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(sid)
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
err := s.Set(i, i)
err := s.Set(nil, i, i)
if err != nil {
t.Error(err)
}
@ -309,14 +310,14 @@ func TestFileSessionStore_Get(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(sid)
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
_ = s.Set(i, i)
_ = s.Set(nil, i, i)
v := s.Get(i)
v := s.Get(nil, i)
if v.(int) != i {
t.Error()
}
@ -330,18 +331,18 @@ func TestFileSessionStore_Delete(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
s, _ := fp.SessionRead(sid)
s.Set("1", 1)
s, _ := fp.SessionRead(context.Background(), sid)
s.Set(nil, "1", 1)
if s.Get("1") == nil {
if s.Get(nil, "1") == nil {
t.Error()
}
s.Delete("1")
s.Delete(nil, "1")
if s.Get("1") != nil {
if s.Get(nil, "1") != nil {
t.Error()
}
}
@ -353,18 +354,18 @@ func TestFileSessionStore_Flush(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(sid)
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
_ = s.Set(i, i)
_ = s.Set(nil, i, i)
}
_ = s.Flush()
_ = s.Flush(nil)
for i := 1; i <= sessionCount; i++ {
if s.Get(i) != nil {
if s.Get(nil, i) != nil {
t.Error()
}
}
@ -377,16 +378,16 @@ func TestFileSessionStore_SessionID(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 85
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i))
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) {
if s.SessionID(nil) != fmt.Sprintf("%s_%d", sid, i) {
t.Error(err)
}
}
@ -399,27 +400,27 @@ func TestFileSessionStore_SessionRelease(t *testing.T) {
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(180, sessionPath)
_ = fp.SessionInit(context.Background(), 180, sessionPath)
filepder.savePath = sessionPath
sessionCount := 85
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i))
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
s.Set(i, i)
s.SessionRelease(nil)
s.Set(nil, i, i)
s.SessionRelease(nil, nil)
}
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i))
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
if s.Get(i).(int) != i {
if s.Get(nil, i).(int) != i {
t.Error()
}
}

View File

@ -16,6 +16,7 @@ package session
import (
"container/list"
"context"
"net/http"
"sync"
"time"
@ -33,7 +34,7 @@ type MemSessionStore struct {
}
// Set value to memory session
func (st *MemSessionStore) Set(key, value interface{}) error {
func (st *MemSessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.value[key] = value
@ -41,7 +42,7 @@ func (st *MemSessionStore) Set(key, value interface{}) error {
}
// Get value from memory session by key
func (st *MemSessionStore) Get(key interface{}) interface{} {
func (st *MemSessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.value[key]; ok {
@ -51,7 +52,7 @@ func (st *MemSessionStore) Get(key interface{}) interface{} {
}
// Delete in memory session by key
func (st *MemSessionStore) Delete(key interface{}) error {
func (st *MemSessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.value, key)
@ -59,7 +60,7 @@ func (st *MemSessionStore) Delete(key interface{}) error {
}
// Flush clear all values in memory session
func (st *MemSessionStore) Flush() error {
func (st *MemSessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.value = make(map[interface{}]interface{})
@ -67,12 +68,12 @@ func (st *MemSessionStore) Flush() error {
}
// SessionID get this id of memory session store
func (st *MemSessionStore) SessionID() string {
func (st *MemSessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease Implement method, no used.
func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
func (st *MemSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
}
// MemProvider Implement the provider interface
@ -85,14 +86,14 @@ type MemProvider struct {
}
// SessionInit init memory session
func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error {
func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
pder.maxlifetime = maxlifetime
pder.savePath = savePath
return nil
}
// SessionRead get memory session store by sid
func (pder *MemProvider) SessionRead(sid string) (Store, error) {
func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
pder.lock.RLock()
if element, ok := pder.sessions[sid]; ok {
go pder.SessionUpdate(sid)
@ -109,7 +110,7 @@ func (pder *MemProvider) SessionRead(sid string) (Store, error) {
}
// SessionExist check session store exist in memory session by sid
func (pder *MemProvider) SessionExist(sid string) (bool, error) {
func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
pder.lock.RLock()
defer pder.lock.RUnlock()
if _, ok := pder.sessions[sid]; ok {
@ -119,7 +120,7 @@ func (pder *MemProvider) SessionExist(sid string) (bool, error) {
}
// SessionRegenerate generate new sid for session store in memory session
func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
pder.lock.RLock()
if element, ok := pder.sessions[oldsid]; ok {
go pder.SessionUpdate(oldsid)
@ -141,7 +142,7 @@ func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
}
// SessionDestroy delete session store in memory session by id
func (pder *MemProvider) SessionDestroy(sid string) error {
func (pder *MemProvider) SessionDestroy(ctx context.Context, sid string) error {
pder.lock.Lock()
defer pder.lock.Unlock()
if element, ok := pder.sessions[sid]; ok {
@ -153,7 +154,7 @@ func (pder *MemProvider) SessionDestroy(sid string) error {
}
// SessionGC clean expired session stores in memory session
func (pder *MemProvider) SessionGC() {
func (pder *MemProvider) SessionGC(context.Context) {
pder.lock.RLock()
for {
element := pder.list.Back()
@ -175,7 +176,7 @@ func (pder *MemProvider) SessionGC() {
}
// SessionAll get count number of memory session
func (pder *MemProvider) SessionAll() int {
func (pder *MemProvider) SessionAll(context.Context) int {
return pder.list.Len()
}

View File

@ -36,12 +36,12 @@ func TestMem(t *testing.T) {
if err != nil {
t.Fatal("set error,", err)
}
defer sess.SessionRelease(w)
err = sess.Set("username", "astaxie")
defer sess.SessionRelease(nil, w)
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
if username := sess.Get(nil, "username"); username != "astaxie" {
t.Fatal("get username error")
}
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {

View File

@ -28,6 +28,7 @@
package session
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
@ -43,24 +44,24 @@ import (
// Store contains all data for one session process with specific id.
type Store interface {
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data
Set(ctx context.Context, key, value interface{}) error //set session value
Get(ctx context.Context, key interface{}) interface{} //get session value
Delete(ctx context.Context, key interface{}) error //delete session value
SessionID(ctx context.Context) string //back current sessionID
SessionRelease(ctx context.Context, w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush(ctx context.Context) error //delete all data
}
// Provider contains global session methods and saved SessionStores.
// it can operate a SessionStore by its id.
type Provider interface {
SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (Store, error)
SessionExist(sid string) (bool, error)
SessionRegenerate(oldsid, sid string) (Store, error)
SessionDestroy(sid string) error
SessionAll() int //get all active session
SessionGC()
SessionInit(ctx context.Context, gclifetime int64, config string) error
SessionRead(ctx context.Context, sid string) (Store, error)
SessionExist(ctx context.Context, sid string) (bool, error)
SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error)
SessionDestroy(ctx context.Context, sid string) error
SessionAll(ctx context.Context) int //get all active session
SessionGC(ctx context.Context)
}
var provides = make(map[string]Provider)
@ -148,7 +149,7 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) {
}
}
err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
err := provider.SessionInit(nil, cf.Maxlifetime, cf.ProviderConfig)
if err != nil {
return nil, err
}
@ -212,12 +213,12 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
}
if sid != "" {
exists, err := manager.provider.SessionExist(sid)
exists, err := manager.provider.SessionExist(nil, sid)
if err != nil {
return nil, err
}
if exists {
return manager.provider.SessionRead(sid)
return manager.provider.SessionRead(nil, sid)
}
}
@ -227,7 +228,7 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
return nil, errs
}
session, err = manager.provider.SessionRead(sid)
session, err = manager.provider.SessionRead(nil, sid)
if err != nil {
return nil, err
}
@ -269,7 +270,7 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
}
sid, _ := url.QueryUnescape(cookie.Value)
manager.provider.SessionDestroy(sid)
manager.provider.SessionDestroy(nil, sid)
if manager.config.EnableSetCookie {
expiration := time.Now()
cookie = &http.Cookie{Name: manager.config.CookieName,
@ -285,14 +286,14 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
// GetSessionStore Get SessionStore by its id.
func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) {
sessions, err = manager.provider.SessionRead(sid)
sessions, err = manager.provider.SessionRead(nil, sid)
return
}
// GC Start session gc process.
// it can do gc in times after gc lifetime.
func (manager *Manager) GC() {
manager.provider.SessionGC()
manager.provider.SessionGC(nil)
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
}
@ -305,7 +306,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
//delete old cookie
session, _ = manager.provider.SessionRead(sid)
session, _ = manager.provider.SessionRead(nil, sid)
cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
@ -315,7 +316,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque
}
} else {
oldsid, _ := url.QueryUnescape(cookie.Value)
session, _ = manager.provider.SessionRegenerate(oldsid, sid)
session, _ = manager.provider.SessionRegenerate(nil, oldsid, sid)
cookie.Value = url.QueryEscape(sid)
cookie.HttpOnly = true
cookie.Path = "/"
@ -339,7 +340,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque
// GetActiveSession Get all active sessions count number.
func (manager *Manager) GetActiveSession() int {
return manager.provider.SessionAll()
return manager.provider.SessionAll(nil)
}
// SetSecure Set cookie with https.

View File

@ -1,6 +1,7 @@
package ssdb
import (
"context"
"errors"
"net/http"
"strconv"
@ -31,7 +32,7 @@ func (p *Provider) connectInit() error {
}
// SessionInit init the ssdb with the config
func (p *Provider) SessionInit(maxLifetime int64, savePath string) error {
func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error {
p.maxLifetime = maxLifetime
address := strings.Split(savePath, ":")
p.host = address[0]
@ -44,7 +45,7 @@ func (p *Provider) SessionInit(maxLifetime int64, savePath string) error {
}
// SessionRead return a ssdb client session Store
func (p *Provider) SessionRead(sid string) (session.Store, error) {
func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if p.client == nil {
if err := p.connectInit(); err != nil {
return nil, err
@ -68,7 +69,7 @@ func (p *Provider) SessionRead(sid string) (session.Store, error) {
}
// SessionExist judged whether sid is exist in session
func (p *Provider) SessionExist(sid string) (bool, error) {
func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
if p.client == nil {
if err := p.connectInit(); err != nil {
return false, err
@ -85,7 +86,7 @@ func (p *Provider) SessionExist(sid string) (bool, error) {
}
// SessionRegenerate regenerate session with new sid and delete oldsid
func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
func (p *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
//conn.Do("setx", key, v, ttl)
if p.client == nil {
if err := p.connectInit(); err != nil {
@ -118,7 +119,7 @@ func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
}
// SessionDestroy destroy the sid
func (p *Provider) SessionDestroy(sid string) error {
func (p *Provider) SessionDestroy(ctx context.Context, sid string) error {
if p.client == nil {
if err := p.connectInit(); err != nil {
return err
@ -129,11 +130,11 @@ func (p *Provider) SessionDestroy(sid string) error {
}
// SessionGC not implemented
func (p *Provider) SessionGC() {
func (p *Provider) SessionGC(context.Context) {
}
// SessionAll not implemented
func (p *Provider) SessionAll() int {
func (p *Provider) SessionAll(context.Context) int {
return 0
}
@ -147,7 +148,7 @@ type SessionStore struct {
}
// Set the key and value
func (s *SessionStore) Set(key, value interface{}) error {
func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
s.values[key] = value
@ -155,7 +156,7 @@ func (s *SessionStore) Set(key, value interface{}) error {
}
// Get return the value by the key
func (s *SessionStore) Get(key interface{}) interface{} {
func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
s.lock.Lock()
defer s.lock.Unlock()
if value, ok := s.values[key]; ok {
@ -165,7 +166,7 @@ func (s *SessionStore) Get(key interface{}) interface{} {
}
// Delete the key in session store
func (s *SessionStore) Delete(key interface{}) error {
func (s *SessionStore) Delete(ctx context.Context, key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.values, key)
@ -173,7 +174,7 @@ func (s *SessionStore) Delete(key interface{}) error {
}
// Flush delete all keys and values
func (s *SessionStore) Flush() error {
func (s *SessionStore) Flush(context.Context) error {
s.lock.Lock()
defer s.lock.Unlock()
s.values = make(map[interface{}]interface{})
@ -181,12 +182,12 @@ func (s *SessionStore) Flush() error {
}
// SessionID return the sessionID
func (s *SessionStore) SessionID() string {
func (s *SessionStore) SessionID(context.Context) string {
return s.sid
}
// SessionRelease Store the keyvalues into ssdb
func (s *SessionStore) SessionRelease(w http.ResponseWriter) {
func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(s.values)
if err != nil {
return

View File

@ -361,7 +361,7 @@ func (input *BeegoInput) Cookie(key string) string {
// Session returns current session item value by a given key.
// if non-existed, return nil.
func (input *BeegoInput) Session(key interface{}) interface{} {
return input.CruSession.Get(key)
return input.CruSession.Get(nil, key)
}
// CopyBody returns the raw request body data as bytes.

View File

@ -404,5 +404,5 @@ func stringsToJSON(str string) string {
// Session sets session item value with given key.
func (output *BeegoOutput) Session(name interface{}, value interface{}) {
output.Context.Input.CruSession.Set(name, value)
output.Context.Input.CruSession.Set(nil, name, value)
}

View File

@ -622,7 +622,7 @@ func (c *Controller) SetSession(name interface{}, value interface{}) {
if c.CruSession == nil {
c.StartSession()
}
c.CruSession.Set(name, value)
c.CruSession.Set(nil, name, value)
}
// GetSession gets value from session.
@ -630,7 +630,7 @@ func (c *Controller) GetSession(name interface{}) interface{} {
if c.CruSession == nil {
c.StartSession()
}
return c.CruSession.Get(name)
return c.CruSession.Get(nil, name)
}
// DelSession removes value from session.
@ -638,14 +638,14 @@ func (c *Controller) DelSession(name interface{}) {
if c.CruSession == nil {
c.StartSession()
}
c.CruSession.Delete(name)
c.CruSession.Delete(nil, name)
}
// SessionRegenerateID regenerates session id for this session.
// the session data have no changes.
func (c *Controller) SessionRegenerateID() {
if c.CruSession != nil {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
c.CruSession.SessionRelease(nil, c.Ctx.ResponseWriter)
}
c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession
@ -653,7 +653,7 @@ func (c *Controller) SessionRegenerateID() {
// DestroySession cleans session data and session cookie.
func (c *Controller) DestroySession() {
c.Ctx.Input.CruSession.Flush()
c.Ctx.Input.CruSession.Flush(nil)
c.Ctx.Input.CruSession = nil
GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request)
}

View File

@ -721,7 +721,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) {
}
defer func() {
if ctx.Input.CruSession != nil {
ctx.Input.CruSession.SessionRelease(rw)
ctx.Input.CruSession.SessionRelease(nil, rw)
}
}()
}