diff --git a/pkg/client/cache/redis/redis_test.go b/pkg/client/cache/redis/redis_test.go index 00206157..dc0ca40f 100644 --- a/pkg/client/cache/redis/redis_test.go +++ b/pkg/client/cache/redis/redis_test.go @@ -129,7 +129,7 @@ func TestCache_Scan(t *testing.T) { t.Error("init err") } // insert all - for i := 0; i < 10000; i++ { + for i := 0; i < 100; i++ { if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { t.Error("set Error", err) } @@ -141,7 +141,7 @@ func TestCache_Scan(t *testing.T) { t.Error("scan Error", err) } - assert.Equal(t, 10000, len(keys), "scan all error") + assert.Equal(t, 100, len(keys), "scan all error") // clear all if err = bm.ClearAll(); err != nil { diff --git a/pkg/infrastructure/session/couchbase/sess_couchbase.go b/pkg/infrastructure/session/couchbase/sess_couchbase.go index 378cfc9f..ddb4be58 100644 --- a/pkg/infrastructure/session/couchbase/sess_couchbase.go +++ b/pkg/infrastructure/session/couchbase/sess_couchbase.go @@ -33,6 +33,7 @@ package couchbase import ( + "context" "net/http" "strings" "sync" @@ -63,7 +64,7 @@ type Provider struct { } // Set value to couchabse session -func (cs *SessionStore) Set(key, value interface{}) error { +func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error { cs.lock.Lock() defer cs.lock.Unlock() cs.values[key] = value @@ -71,7 +72,7 @@ func (cs *SessionStore) Set(key, value interface{}) error { } // Get value from couchabse session -func (cs *SessionStore) Get(key interface{}) interface{} { +func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { cs.lock.RLock() defer cs.lock.RUnlock() if v, ok := cs.values[key]; ok { @@ -81,7 +82,7 @@ func (cs *SessionStore) Get(key interface{}) interface{} { } // Delete value in couchbase session by given key -func (cs *SessionStore) Delete(key interface{}) error { +func (cs *SessionStore) Delete(ctx context.Context, key interface{}) error { cs.lock.Lock() defer cs.lock.Unlock() delete(cs.values, key) @@ -89,7 +90,7 @@ func (cs *SessionStore) Delete(key interface{}) error { } // Flush Clean all values in couchbase session -func (cs *SessionStore) Flush() error { +func (cs *SessionStore) Flush(context.Context) error { cs.lock.Lock() defer cs.lock.Unlock() cs.values = make(map[interface{}]interface{}) @@ -97,12 +98,12 @@ func (cs *SessionStore) Flush() error { } // SessionID Get couchbase session store id -func (cs *SessionStore) SessionID() string { +func (cs *SessionStore) SessionID(context.Context) string { return cs.sid } // SessionRelease Write couchbase session with Gob string -func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { defer cs.b.Close() bo, err := session.EncodeGob(cs.values) @@ -135,7 +136,7 @@ func (cp *Provider) getBucket() *couchbase.Bucket { // SessionInit init couchbase session // savepath like couchbase server REST/JSON URL // e.g. http://host:port/, Pool, Bucket -func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { cp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -152,7 +153,7 @@ func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read couchbase session by sid -func (cp *Provider) SessionRead(sid string) (session.Store, error) { +func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { cp.b = cp.getBucket() var ( @@ -179,7 +180,7 @@ func (cp *Provider) SessionRead(sid string) (session.Store, error) { // SessionExist Check couchbase session exist. // it checkes sid exist or not. -func (cp *Provider) SessionExist(sid string) (bool, error) { +func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { cp.b = cp.getBucket() defer cp.b.Close() @@ -192,7 +193,7 @@ func (cp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate remove oldsid and use sid to generate new session -func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { cp.b = cp.getBucket() var doc []byte @@ -225,7 +226,7 @@ func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy Remove bucket in this couchbase -func (cp *Provider) SessionDestroy(sid string) error { +func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error { cp.b = cp.getBucket() defer cp.b.Close() @@ -234,11 +235,11 @@ func (cp *Provider) SessionDestroy(sid string) error { } // SessionGC Recycle -func (cp *Provider) SessionGC() { +func (cp *Provider) SessionGC(context.Context) { } // SessionAll return all active session -func (cp *Provider) SessionAll() int { +func (cp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/ledis/ledis_session.go b/pkg/infrastructure/session/ledis/ledis_session.go index 96e6efa3..74bf9b65 100644 --- a/pkg/infrastructure/session/ledis/ledis_session.go +++ b/pkg/infrastructure/session/ledis/ledis_session.go @@ -2,6 +2,7 @@ package ledis import ( + "context" "net/http" "strconv" "strings" @@ -27,7 +28,7 @@ type SessionStore struct { } // Set value in ledis session -func (ls *SessionStore) Set(key, value interface{}) error { +func (ls *SessionStore) Set(ctx context.Context, key, value interface{}) error { ls.lock.Lock() defer ls.lock.Unlock() ls.values[key] = value @@ -35,7 +36,7 @@ func (ls *SessionStore) Set(key, value interface{}) error { } // Get value in ledis session -func (ls *SessionStore) Get(key interface{}) interface{} { +func (ls *SessionStore) Get(ctx context.Context, key interface{}) interface{} { ls.lock.RLock() defer ls.lock.RUnlock() if v, ok := ls.values[key]; ok { @@ -45,7 +46,7 @@ func (ls *SessionStore) Get(key interface{}) interface{} { } // Delete value in ledis session -func (ls *SessionStore) Delete(key interface{}) error { +func (ls *SessionStore) Delete(ctx context.Context, key interface{}) error { ls.lock.Lock() defer ls.lock.Unlock() delete(ls.values, key) @@ -53,7 +54,7 @@ func (ls *SessionStore) Delete(key interface{}) error { } // Flush clear all values in ledis session -func (ls *SessionStore) Flush() error { +func (ls *SessionStore) Flush(context.Context) error { ls.lock.Lock() defer ls.lock.Unlock() ls.values = make(map[interface{}]interface{}) @@ -61,12 +62,12 @@ func (ls *SessionStore) Flush() error { } // SessionID get ledis session id -func (ls *SessionStore) SessionID() string { +func (ls *SessionStore) SessionID(context.Context) string { return ls.sid } // SessionRelease save session values to ledis -func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { +func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(ls.values) if err != nil { return @@ -85,7 +86,7 @@ type Provider struct { // SessionInit init ledis session // savepath like ledis server saveDataPath,pool size // e.g. 127.0.0.1:6379,100,astaxie -func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { var err error lp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") @@ -111,7 +112,7 @@ func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read ledis session by sid -func (lp *Provider) SessionRead(sid string) (session.Store, error) { +func (lp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var ( kv map[interface{}]interface{} err error @@ -132,13 +133,13 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check ledis session exist by sid -func (lp *Provider) SessionExist(sid string) (bool, error) { +func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { count, _ := c.Exists([]byte(sid)) return count != 0, nil } // SessionRegenerate generate new sid for ledis session -func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { count, _ := c.Exists([]byte(sid)) if count == 0 { // oldsid doesn't exists, set the new sid directly @@ -151,21 +152,21 @@ func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Set([]byte(sid), data) c.Expire([]byte(sid), lp.maxlifetime) } - return lp.SessionRead(sid) + return lp.SessionRead(context.Background(), sid) } // SessionDestroy delete ledis session by id -func (lp *Provider) SessionDestroy(sid string) error { +func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error { c.Del([]byte(sid)) return nil } // SessionGC Impelment method, no used. -func (lp *Provider) SessionGC() { +func (lp *Provider) SessionGC(context.Context) { } // SessionAll return all active session -func (lp *Provider) SessionAll() int { +func (lp *Provider) SessionAll(context.Context) int { return 0 } func init() { diff --git a/pkg/infrastructure/session/memcache/sess_memcache.go b/pkg/infrastructure/session/memcache/sess_memcache.go index 0758c43f..57df2844 100644 --- a/pkg/infrastructure/session/memcache/sess_memcache.go +++ b/pkg/infrastructure/session/memcache/sess_memcache.go @@ -33,6 +33,7 @@ package memcache import ( + "context" "net/http" "strings" "sync" @@ -54,7 +55,7 @@ type SessionStore struct { } // Set value in memcache session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -62,7 +63,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in memcache session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -72,7 +73,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in memcache session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -80,7 +81,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in memcache session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -88,12 +89,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get memcache session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to memcache -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -113,7 +114,7 @@ type MemProvider struct { // SessionInit init memcache session // savepath like // e.g. 127.0.0.1:9090 -func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime rp.conninfo = strings.Split(savePath, ";") client = memcache.New(rp.conninfo...) @@ -121,7 +122,7 @@ func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read memcache session by sid -func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { +func (rp *MemProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) { if client == nil { if err := rp.connectInit(); err != nil { return nil, err @@ -149,7 +150,7 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { } // SessionExist check memcache session exist by sid -func (rp *MemProvider) SessionExist(sid string) (bool, error) { +func (rp *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) { if client == nil { if err := rp.connectInit(); err != nil { return false, err @@ -162,7 +163,7 @@ func (rp *MemProvider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for memcache session -func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { if client == nil { if err := rp.connectInit(); err != nil { return nil, err @@ -201,7 +202,7 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, err } // SessionDestroy delete memcache session by id -func (rp *MemProvider) SessionDestroy(sid string) error { +func (rp *MemProvider) SessionDestroy(ctx context.Context, sid string) error { if client == nil { if err := rp.connectInit(); err != nil { return err @@ -217,11 +218,11 @@ func (rp *MemProvider) connectInit() error { } // SessionGC Impelment method, no used. -func (rp *MemProvider) SessionGC() { +func (rp *MemProvider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *MemProvider) SessionAll() int { +func (rp *MemProvider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/mysql/sess_mysql.go b/pkg/infrastructure/session/mysql/sess_mysql.go index 2dadd317..fe1d69dc 100644 --- a/pkg/infrastructure/session/mysql/sess_mysql.go +++ b/pkg/infrastructure/session/mysql/sess_mysql.go @@ -41,6 +41,7 @@ package mysql import ( + "context" "database/sql" "net/http" "sync" @@ -67,7 +68,7 @@ type SessionStore struct { // Set value in mysql session. // it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { +func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -75,7 +76,7 @@ func (st *SessionStore) Set(key, value interface{}) error { } // Get value from mysql session -func (st *SessionStore) Get(key interface{}) interface{} { +func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -85,7 +86,7 @@ func (st *SessionStore) Get(key interface{}) interface{} { } // Delete value in mysql session -func (st *SessionStore) Delete(key interface{}) error { +func (st *SessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -93,7 +94,7 @@ func (st *SessionStore) Delete(key interface{}) error { } // Flush clear all values in mysql session -func (st *SessionStore) Flush() error { +func (st *SessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -101,13 +102,13 @@ func (st *SessionStore) Flush() error { } // SessionID get session id of this mysql session store -func (st *SessionStore) SessionID() string { +func (st *SessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease save mysql session values to database. // must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { +func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { defer st.c.Close() b, err := session.EncodeGob(st.values) if err != nil { @@ -134,14 +135,14 @@ func (mp *Provider) connectInit() *sql.DB { // SessionInit init mysql session. // savepath is the connection string of mysql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { mp.maxlifetime = maxlifetime mp.savePath = savePath return nil } // SessionRead get mysql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { +func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte @@ -164,7 +165,7 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check mysql session exist -func (mp *Provider) SessionExist(sid string) (bool, error) { +func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) @@ -180,7 +181,7 @@ func (mp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for mysql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) var sessiondata []byte @@ -203,7 +204,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy delete mysql session by sid -func (mp *Provider) SessionDestroy(sid string) error { +func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := mp.connectInit() c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) c.Close() @@ -211,14 +212,14 @@ func (mp *Provider) SessionDestroy(sid string) error { } // SessionGC delete expired values in mysql session -func (mp *Provider) SessionGC() { +func (mp *Provider) SessionGC(context.Context) { c := mp.connectInit() c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) c.Close() } // SessionAll count values in mysql session -func (mp *Provider) SessionAll() int { +func (mp *Provider) SessionAll(context.Context) int { c := mp.connectInit() defer c.Close() var total int diff --git a/pkg/infrastructure/session/postgres/sess_postgresql.go b/pkg/infrastructure/session/postgres/sess_postgresql.go index adcf647b..2fadbed0 100644 --- a/pkg/infrastructure/session/postgres/sess_postgresql.go +++ b/pkg/infrastructure/session/postgres/sess_postgresql.go @@ -51,6 +51,7 @@ package postgres import ( + "context" "database/sql" "net/http" "sync" @@ -73,7 +74,7 @@ type SessionStore struct { // Set value in postgresql session. // it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { +func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -81,7 +82,7 @@ func (st *SessionStore) Set(key, value interface{}) error { } // Get value from postgresql session -func (st *SessionStore) Get(key interface{}) interface{} { +func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -91,7 +92,7 @@ func (st *SessionStore) Get(key interface{}) interface{} { } // Delete value in postgresql session -func (st *SessionStore) Delete(key interface{}) error { +func (st *SessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -99,7 +100,7 @@ func (st *SessionStore) Delete(key interface{}) error { } // Flush clear all values in postgresql session -func (st *SessionStore) Flush() error { +func (st *SessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -107,13 +108,13 @@ func (st *SessionStore) Flush() error { } // SessionID get session id of this postgresql session store -func (st *SessionStore) SessionID() string { +func (st *SessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease save postgresql session values to database. // must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { +func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { defer st.c.Close() b, err := session.EncodeGob(st.values) if err != nil { @@ -141,14 +142,14 @@ func (mp *Provider) connectInit() *sql.DB { // SessionInit init postgresql session. // savepath is the connection string of postgresql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { mp.maxlifetime = maxlifetime mp.savePath = savePath return nil } // SessionRead get postgresql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { +func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte @@ -178,7 +179,7 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check postgresql session exist -func (mp *Provider) SessionExist(sid string) (bool, error) { +func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from session where session_key=$1", sid) @@ -194,7 +195,7 @@ func (mp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for postgresql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=$1", oldsid) var sessiondata []byte @@ -218,7 +219,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy delete postgresql session by sid -func (mp *Provider) SessionDestroy(sid string) error { +func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := mp.connectInit() c.Exec("DELETE FROM session where session_key=$1", sid) c.Close() @@ -226,14 +227,14 @@ func (mp *Provider) SessionDestroy(sid string) error { } // SessionGC delete expired values in postgresql session -func (mp *Provider) SessionGC() { +func (mp *Provider) SessionGC(context.Context) { c := mp.connectInit() c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) c.Close() } // SessionAll count values in postgresql session -func (mp *Provider) SessionAll() int { +func (mp *Provider) SessionAll(context.Context) int { c := mp.connectInit() defer c.Close() var total int diff --git a/pkg/infrastructure/session/redis/sess_redis.go b/pkg/infrastructure/session/redis/sess_redis.go index e775102c..c7bfbcbf 100644 --- a/pkg/infrastructure/session/redis/sess_redis.go +++ b/pkg/infrastructure/session/redis/sess_redis.go @@ -33,6 +33,7 @@ package redis import ( + "context" "net/http" "strconv" "strings" @@ -59,7 +60,7 @@ type SessionStore struct { } // Set value in redis session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -67,7 +68,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in redis session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -77,7 +78,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in redis session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -85,7 +86,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in redis session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -93,12 +94,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get redis session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to redis -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -123,7 +124,7 @@ type Provider struct { // SessionInit init redis session // savepath like redis server addr,pool size,password,dbnum,IdleTimeout second // e.g. 127.0.0.1:6379,100,astaxie,0,30 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -185,7 +186,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read redis session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var kv map[interface{}]interface{} kvs, err := rp.poollist.Get(sid).Result() @@ -205,7 +206,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { @@ -215,7 +216,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for redis session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := rp.poollist if existed, _ := c.Exists(oldsid).Result(); existed == 0 { // oldsid doesn't exists, set the new sid directly @@ -226,11 +227,11 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Rename(oldsid, sid) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } - return rp.SessionRead(sid) + return rp.SessionRead(context.Background(), sid) } // SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := rp.poollist c.Del(sid) @@ -238,11 +239,11 @@ func (rp *Provider) SessionDestroy(sid string) error { } // SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { +func (rp *Provider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *Provider) SessionAll() int { +func (rp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/redis/sess_redis_test.go b/pkg/infrastructure/session/redis/sess_redis_test.go index ef466eab..df77204d 100644 --- a/pkg/infrastructure/session/redis/sess_redis_test.go +++ b/pkg/infrastructure/session/redis/sess_redis_test.go @@ -40,57 +40,57 @@ func TestRedis(t *testing.T) { if err != nil { t.Fatal("session start failed:", err) } - defer sess.SessionRelease(w) + defer sess.SessionRelease(nil, w) // SET AND GET - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set username failed:", err) } - username := sess.Get("username") + username := sess.Get(nil, "username") if username != "astaxie" { t.Fatal("get username failed") } // DELETE - err = sess.Delete("username") + err = sess.Delete(nil, "username") if err != nil { t.Fatal("delete username failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != nil { t.Fatal("delete username failed") } // FLUSH - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set failed:", err) } - err = sess.Set("password", "1qaz2wsx") + err = sess.Set(nil, "password", "1qaz2wsx") if err != nil { t.Fatal("set failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != "astaxie" { t.Fatal("get username failed") } - password := sess.Get("password") + password := sess.Get(nil, "password") if password != "1qaz2wsx" { t.Fatal("get password failed") } - err = sess.Flush() + err = sess.Flush(nil) if err != nil { t.Fatal("flush failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != nil { t.Fatal("flush failed") } - password = sess.Get("password") + password = sess.Get(nil, "password") if password != nil { t.Fatal("flush failed") } - sess.SessionRelease(w) + sess.SessionRelease(nil, w) } diff --git a/pkg/infrastructure/session/redis_cluster/redis_cluster.go b/pkg/infrastructure/session/redis_cluster/redis_cluster.go index 40487d76..95907a5f 100644 --- a/pkg/infrastructure/session/redis_cluster/redis_cluster.go +++ b/pkg/infrastructure/session/redis_cluster/redis_cluster.go @@ -33,6 +33,7 @@ package redis_cluster import ( + "context" "net/http" "strconv" "strings" @@ -58,7 +59,7 @@ type SessionStore struct { } // Set value in redis_cluster session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -66,7 +67,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in redis_cluster session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -76,7 +77,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in redis_cluster session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -84,7 +85,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in redis_cluster session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -92,12 +93,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get redis_cluster session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to redis_cluster -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -122,7 +123,7 @@ type Provider struct { // SessionInit init redis_cluster session // savepath like redis server addr,pool size,password,dbnum // e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -182,7 +183,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read redis_cluster session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var kv map[interface{}]interface{} kvs, err := rp.poollist.Get(sid).Result() if err != nil && err != rediss.Nil { @@ -201,7 +202,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_cluster session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { return false, err @@ -210,7 +211,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for redis_cluster session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := rp.poollist if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { @@ -222,22 +223,22 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Rename(oldsid, sid) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } - return rp.SessionRead(sid) + return rp.SessionRead(context.Background(), sid) } // SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := rp.poollist c.Del(sid) return nil } // SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { +func (rp *Provider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *Provider) SessionAll() int { +func (rp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go index 1f6ebaa7..1b9c841b 100644 --- a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go +++ b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel.go @@ -33,6 +33,7 @@ package redis_sentinel import ( + "context" "net/http" "strconv" "strings" @@ -58,7 +59,7 @@ type SessionStore struct { } // Set value in redis_sentinel session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -66,7 +67,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in redis_sentinel session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -76,7 +77,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in redis_sentinel session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -84,7 +85,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in redis_sentinel session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -92,12 +93,12 @@ func (rs *SessionStore) Flush() error { } // SessionID get redis_sentinel session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to redis_sentinel -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(rs.values) if err != nil { return @@ -123,7 +124,7 @@ type Provider struct { // SessionInit init redis_sentinel session // savepath like redis sentinel addr,pool size,password,dbnum,masterName // e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime configs := strings.Split(savePath, ",") if len(configs) > 0 { @@ -195,7 +196,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read redis_sentinel session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { var kv map[interface{}]interface{} kvs, err := rp.poollist.Get(sid).Result() if err != nil && err != redis.Nil { @@ -214,7 +215,7 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check redis_sentinel session exist by sid -func (rp *Provider) SessionExist(sid string) (bool, error) { +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := rp.poollist if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { return false, err @@ -223,7 +224,7 @@ func (rp *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for redis_sentinel session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := rp.poollist if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { @@ -235,22 +236,22 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) c.Rename(oldsid, sid) c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) } - return rp.SessionRead(sid) + return rp.SessionRead(context.Background(), sid) } // SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := rp.poollist c.Del(sid) return nil } // SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { +func (rp *Provider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *Provider) SessionAll() int { +func (rp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go index 0dc3520a..fcec9806 100644 --- a/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/pkg/infrastructure/session/redis_sentinel/sess_redis_sentinel_test.go @@ -33,58 +33,58 @@ func TestRedisSentinel(t *testing.T) { if err != nil { t.Fatal("session start failed:", err) } - defer sess.SessionRelease(w) + defer sess.SessionRelease(nil, w) // SET AND GET - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set username failed:", err) } - username := sess.Get("username") + username := sess.Get(nil, "username") if username != "astaxie" { t.Fatal("get username failed") } // DELETE - err = sess.Delete("username") + err = sess.Delete(nil, "username") if err != nil { t.Fatal("delete username failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != nil { t.Fatal("delete username failed") } // FLUSH - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set failed:", err) } - err = sess.Set("password", "1qaz2wsx") + err = sess.Set(nil, "password", "1qaz2wsx") if err != nil { t.Fatal("set failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != "astaxie" { t.Fatal("get username failed") } - password := sess.Get("password") + password := sess.Get(nil, "password") if password != "1qaz2wsx" { t.Fatal("get password failed") } - err = sess.Flush() + err = sess.Flush(nil) if err != nil { t.Fatal("flush failed:", err) } - username = sess.Get("username") + username = sess.Get(nil, "username") if username != nil { t.Fatal("flush failed") } - password = sess.Get("password") + password = sess.Get(nil, "password") if password != nil { t.Fatal("flush failed") } - sess.SessionRelease(w) + sess.SessionRelease(nil, w) } diff --git a/pkg/infrastructure/session/sess_cookie.go b/pkg/infrastructure/session/sess_cookie.go index 30a7032e..ffb19fb7 100644 --- a/pkg/infrastructure/session/sess_cookie.go +++ b/pkg/infrastructure/session/sess_cookie.go @@ -15,6 +15,7 @@ package session import ( + "context" "crypto/aes" "crypto/cipher" "encoding/json" @@ -34,7 +35,7 @@ type CookieSessionStore struct { // Set value to cookie session. // the value are encoded as gob with hash block string. -func (st *CookieSessionStore) Set(key, value interface{}) error { +func (st *CookieSessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -42,7 +43,7 @@ func (st *CookieSessionStore) Set(key, value interface{}) error { } // Get value from cookie session -func (st *CookieSessionStore) Get(key interface{}) interface{} { +func (st *CookieSessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -52,7 +53,7 @@ func (st *CookieSessionStore) Get(key interface{}) interface{} { } // Delete value in cookie session -func (st *CookieSessionStore) Delete(key interface{}) error { +func (st *CookieSessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -60,7 +61,7 @@ func (st *CookieSessionStore) Delete(key interface{}) error { } // Flush Clean all values in cookie session -func (st *CookieSessionStore) Flush() error { +func (st *CookieSessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -68,12 +69,12 @@ func (st *CookieSessionStore) Flush() error { } // SessionID Return id of this cookie session -func (st *CookieSessionStore) SessionID() string { +func (st *CookieSessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease Write cookie session to http response cookie -func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { +func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { st.lock.Lock() encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) st.lock.Unlock() @@ -112,7 +113,7 @@ type CookieProvider struct { // securityName - recognized name in encoded cookie string // cookieName - cookie name // maxage - cookie max life time. -func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { +func (pder *CookieProvider) SessionInit(ctx context.Context, maxlifetime int64, config string) error { pder.config = &cookieConfig{} err := json.Unmarshal([]byte(config), pder.config) if err != nil { @@ -134,7 +135,7 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error // SessionRead Get SessionStore in cooke. // decode cooke string to map and put into SessionStore with sid. -func (pder *CookieProvider) SessionRead(sid string) (Store, error) { +func (pder *CookieProvider) SessionRead(ctx context.Context, sid string) (Store, error) { maps, _ := decodeCookie(pder.block, pder.config.SecurityKey, pder.config.SecurityName, @@ -147,26 +148,26 @@ func (pder *CookieProvider) SessionRead(sid string) (Store, error) { } // SessionExist Cookie session is always existed -func (pder *CookieProvider) SessionExist(sid string) (bool, error) { +func (pder *CookieProvider) SessionExist(ctx context.Context, sid string) (bool, error) { return true, nil } // SessionRegenerate Implement method, no used. -func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (pder *CookieProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { return nil, nil } // SessionDestroy Implement method, no used. -func (pder *CookieProvider) SessionDestroy(sid string) error { +func (pder *CookieProvider) SessionDestroy(ctx context.Context, sid string) error { return nil } // SessionGC Implement method, no used. -func (pder *CookieProvider) SessionGC() { +func (pder *CookieProvider) SessionGC(context.Context) { } // SessionAll Implement method, return 0. -func (pder *CookieProvider) SessionAll() int { +func (pder *CookieProvider) SessionAll(context.Context) int { return 0 } diff --git a/pkg/infrastructure/session/sess_cookie_test.go b/pkg/infrastructure/session/sess_cookie_test.go index b6726005..a9fc876d 100644 --- a/pkg/infrastructure/session/sess_cookie_test.go +++ b/pkg/infrastructure/session/sess_cookie_test.go @@ -38,14 +38,14 @@ func TestCookie(t *testing.T) { if err != nil { t.Fatal("set error,", err) } - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set error,", err) } - if username := sess.Get("username"); username != "astaxie" { + if username := sess.Get(nil, "username"); username != "astaxie" { t.Fatal("get username error") } - sess.SessionRelease(w) + sess.SessionRelease(nil, w) if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { t.Fatal("setcookie error") } else { @@ -85,7 +85,7 @@ func TestDestorySessionCookie(t *testing.T) { if err != nil { t.Fatal("session start err,", err) } - if newSession.SessionID() != session.SessionID() { + if newSession.SessionID(nil) != session.SessionID(nil) { t.Fatal("get cookie session id is not the same again.") } @@ -99,7 +99,7 @@ func TestDestorySessionCookie(t *testing.T) { if err != nil { t.Fatal("session start error") } - if newSession.SessionID() == session.SessionID() { + if newSession.SessionID(nil) == session.SessionID(nil) { t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") } } diff --git a/pkg/infrastructure/session/sess_file.go b/pkg/infrastructure/session/sess_file.go index 37d5bd68..90de9a79 100644 --- a/pkg/infrastructure/session/sess_file.go +++ b/pkg/infrastructure/session/sess_file.go @@ -15,6 +15,7 @@ package session import ( + "context" "errors" "fmt" "io/ioutil" @@ -40,7 +41,7 @@ type FileSessionStore struct { } // Set value to file session -func (fs *FileSessionStore) Set(key, value interface{}) error { +func (fs *FileSessionStore) Set(ctx context.Context, key, value interface{}) error { fs.lock.Lock() defer fs.lock.Unlock() fs.values[key] = value @@ -48,7 +49,7 @@ func (fs *FileSessionStore) Set(key, value interface{}) error { } // Get value from file session -func (fs *FileSessionStore) Get(key interface{}) interface{} { +func (fs *FileSessionStore) Get(ctx context.Context, key interface{}) interface{} { fs.lock.RLock() defer fs.lock.RUnlock() if v, ok := fs.values[key]; ok { @@ -58,7 +59,7 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} { } // Delete value in file session by given key -func (fs *FileSessionStore) Delete(key interface{}) error { +func (fs *FileSessionStore) Delete(ctx context.Context, key interface{}) error { fs.lock.Lock() defer fs.lock.Unlock() delete(fs.values, key) @@ -66,7 +67,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error { } // Flush Clean all values in file session -func (fs *FileSessionStore) Flush() error { +func (fs *FileSessionStore) Flush(context.Context) error { fs.lock.Lock() defer fs.lock.Unlock() fs.values = make(map[interface{}]interface{}) @@ -74,12 +75,12 @@ func (fs *FileSessionStore) Flush() error { } // SessionID Get file session store id -func (fs *FileSessionStore) SessionID() string { +func (fs *FileSessionStore) SessionID(context.Context) string { return fs.sid } // SessionRelease Write file session to local file with Gob string -func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { +func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { filepder.lock.Lock() defer filepder.lock.Unlock() b, err := EncodeGob(fs.values) @@ -119,7 +120,7 @@ type FileProvider struct { // SessionInit Init file session provider. // savePath sets the session files path. -func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { +func (fp *FileProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { fp.maxlifetime = maxlifetime fp.savePath = savePath return nil @@ -128,7 +129,7 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { // SessionRead Read file session by sid. // if file is not exist, create it. // the file path is generated from sid string. -func (fp *FileProvider) SessionRead(sid string) (Store, error) { +func (fp *FileProvider) SessionRead(ctx context.Context, sid string) (Store, error) { invalidChars := "./" if strings.ContainsAny(sid, invalidChars) { return nil, errors.New("the sid shouldn't have following characters: " + invalidChars) @@ -176,7 +177,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { // SessionExist Check file session exist. // it checks the file named from sid exist or not. -func (fp *FileProvider) SessionExist(sid string) (bool, error) { +func (fp *FileProvider) SessionExist(ctx context.Context, sid string) (bool, error) { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -190,7 +191,7 @@ func (fp *FileProvider) SessionExist(sid string) (bool, error) { } // SessionDestroy Remove all files in this save path -func (fp *FileProvider) SessionDestroy(sid string) error { +func (fp *FileProvider) SessionDestroy(ctx context.Context, sid string) error { filepder.lock.Lock() defer filepder.lock.Unlock() os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) @@ -198,7 +199,7 @@ func (fp *FileProvider) SessionDestroy(sid string) error { } // SessionGC Recycle files in save path -func (fp *FileProvider) SessionGC() { +func (fp *FileProvider) SessionGC(context.Context) { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -208,7 +209,7 @@ func (fp *FileProvider) SessionGC() { // SessionAll Get active file session number. // it walks save path to count files. -func (fp *FileProvider) SessionAll() int { +func (fp *FileProvider) SessionAll(context.Context) int { a := &activeSession{} err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { return a.visit(path, f, err) @@ -222,7 +223,7 @@ func (fp *FileProvider) SessionAll() int { // SessionRegenerate Generate new sid for file session. // it delete old file and create new file named from new sid. -func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { filepder.lock.Lock() defer filepder.lock.Unlock() diff --git a/pkg/infrastructure/session/sess_file_test.go b/pkg/infrastructure/session/sess_file_test.go index a27d30a6..f40de69f 100644 --- a/pkg/infrastructure/session/sess_file_test.go +++ b/pkg/infrastructure/session/sess_file_test.go @@ -15,6 +15,7 @@ package session import ( + "context" "fmt" "os" "sync" @@ -37,7 +38,7 @@ func TestFileProvider_SessionInit(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) if fp.maxlifetime != 180 { t.Error() } @@ -54,9 +55,9 @@ func TestFileProvider_SessionExist(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - exists, err := fp.SessionExist(sid) + exists, err := fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -64,12 +65,12 @@ func TestFileProvider_SessionExist(t *testing.T) { t.Error() } - _, err = fp.SessionRead(sid) + _, err = fp.SessionRead(context.Background(), sid) if err != nil { t.Error(err) } - exists, err = fp.SessionExist(sid) + exists, err = fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -85,9 +86,9 @@ func TestFileProvider_SessionExist2(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - exists, err := fp.SessionExist(sid) + exists, err := fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -95,7 +96,7 @@ func TestFileProvider_SessionExist2(t *testing.T) { t.Error() } - exists, err = fp.SessionExist("") + exists, err = fp.SessionExist(context.Background(), "") if err == nil { t.Error() } @@ -103,7 +104,7 @@ func TestFileProvider_SessionExist2(t *testing.T) { t.Error() } - exists, err = fp.SessionExist("1") + exists, err = fp.SessionExist(context.Background(), "1") if err == nil { t.Error() } @@ -119,15 +120,15 @@ func TestFileProvider_SessionRead(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - s, err := fp.SessionRead(sid) + s, err := fp.SessionRead(context.Background(), sid) if err != nil { t.Error(err) } - _ = s.Set("sessionValue", 18975) - v := s.Get("sessionValue") + _ = s.Set(nil, "sessionValue", 18975) + v := s.Get(nil, "sessionValue") if v.(int) != 18975 { t.Error() @@ -141,14 +142,14 @@ func TestFileProvider_SessionRead1(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - _, err := fp.SessionRead("") + _, err := fp.SessionRead(context.Background(), "") if err == nil { t.Error(err) } - _, err = fp.SessionRead("1") + _, err = fp.SessionRead(context.Background(), "1") if err == nil { t.Error(err) } @@ -161,18 +162,18 @@ func TestFileProvider_SessionAll(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) sessionCount := 546 for i := 1; i <= sessionCount; i++ { - _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + _, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) if err != nil { t.Error(err) } } - if fp.SessionAll() != sessionCount { + if fp.SessionAll(nil) != sessionCount { t.Error() } } @@ -184,14 +185,14 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - _, err := fp.SessionRead(sid) + _, err := fp.SessionRead(context.Background(), sid) if err != nil { t.Error(err) } - exists, err := fp.SessionExist(sid) + exists, err := fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -199,12 +200,12 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error() } - _, err = fp.SessionRegenerate(sid, sidNew) + _, err = fp.SessionRegenerate(context.Background(), sid, sidNew) if err != nil { t.Error(err) } - exists, err = fp.SessionExist(sid) + exists, err = fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -212,7 +213,7 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { t.Error() } - exists, err = fp.SessionExist(sidNew) + exists, err = fp.SessionExist(context.Background(), sidNew) if err != nil { t.Error(err) } @@ -228,14 +229,14 @@ func TestFileProvider_SessionDestroy(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - _, err := fp.SessionRead(sid) + _, err := fp.SessionRead(context.Background(), sid) if err != nil { t.Error(err) } - exists, err := fp.SessionExist(sid) + exists, err := fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -243,12 +244,12 @@ func TestFileProvider_SessionDestroy(t *testing.T) { t.Error() } - err = fp.SessionDestroy(sid) + err = fp.SessionDestroy(context.Background(), sid) if err != nil { t.Error(err) } - exists, err = fp.SessionExist(sid) + exists, err = fp.SessionExist(context.Background(), sid) if err != nil { t.Error(err) } @@ -264,12 +265,12 @@ func TestFileProvider_SessionGC(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(1, sessionPath) + _ = fp.SessionInit(context.Background(), 1, sessionPath) sessionCount := 412 for i := 1; i <= sessionCount; i++ { - _, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + _, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) if err != nil { t.Error(err) } @@ -277,8 +278,8 @@ func TestFileProvider_SessionGC(t *testing.T) { time.Sleep(2 * time.Second) - fp.SessionGC() - if fp.SessionAll() != 0 { + fp.SessionGC(nil) + if fp.SessionAll(nil) != 0 { t.Error() } } @@ -290,12 +291,12 @@ func TestFileSessionStore_Set(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) sessionCount := 100 - s, _ := fp.SessionRead(sid) + s, _ := fp.SessionRead(context.Background(), sid) for i := 1; i <= sessionCount; i++ { - err := s.Set(i, i) + err := s.Set(nil, i, i) if err != nil { t.Error(err) } @@ -309,14 +310,14 @@ func TestFileSessionStore_Get(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) sessionCount := 100 - s, _ := fp.SessionRead(sid) + s, _ := fp.SessionRead(context.Background(), sid) for i := 1; i <= sessionCount; i++ { - _ = s.Set(i, i) + _ = s.Set(nil, i, i) - v := s.Get(i) + v := s.Get(nil, i) if v.(int) != i { t.Error() } @@ -330,18 +331,18 @@ func TestFileSessionStore_Delete(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) - s, _ := fp.SessionRead(sid) - s.Set("1", 1) + s, _ := fp.SessionRead(context.Background(), sid) + s.Set(nil, "1", 1) - if s.Get("1") == nil { + if s.Get(nil, "1") == nil { t.Error() } - s.Delete("1") + s.Delete(nil, "1") - if s.Get("1") != nil { + if s.Get(nil, "1") != nil { t.Error() } } @@ -353,18 +354,18 @@ func TestFileSessionStore_Flush(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) sessionCount := 100 - s, _ := fp.SessionRead(sid) + s, _ := fp.SessionRead(context.Background(), sid) for i := 1; i <= sessionCount; i++ { - _ = s.Set(i, i) + _ = s.Set(nil, i, i) } - _ = s.Flush() + _ = s.Flush(nil) for i := 1; i <= sessionCount; i++ { - if s.Get(i) != nil { + if s.Get(nil, i) != nil { t.Error() } } @@ -377,16 +378,16 @@ func TestFileSessionStore_SessionID(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) sessionCount := 85 for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) if err != nil { t.Error(err) } - if s.SessionID() != fmt.Sprintf("%s_%d", sid, i) { + if s.SessionID(nil) != fmt.Sprintf("%s_%d", sid, i) { t.Error(err) } } @@ -399,27 +400,27 @@ func TestFileSessionStore_SessionRelease(t *testing.T) { defer os.RemoveAll(sessionPath) fp := &FileProvider{} - _ = fp.SessionInit(180, sessionPath) + _ = fp.SessionInit(context.Background(), 180, sessionPath) filepder.savePath = sessionPath sessionCount := 85 for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) if err != nil { t.Error(err) } - s.Set(i, i) - s.SessionRelease(nil) + s.Set(nil, i, i) + s.SessionRelease(nil, nil) } for i := 1; i <= sessionCount; i++ { - s, err := fp.SessionRead(fmt.Sprintf("%s_%d", sid, i)) + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) if err != nil { t.Error(err) } - if s.Get(i).(int) != i { + if s.Get(nil, i).(int) != i { t.Error() } } diff --git a/pkg/infrastructure/session/sess_mem.go b/pkg/infrastructure/session/sess_mem.go index bd69ff80..9a27c331 100644 --- a/pkg/infrastructure/session/sess_mem.go +++ b/pkg/infrastructure/session/sess_mem.go @@ -16,6 +16,7 @@ package session import ( "container/list" + "context" "net/http" "sync" "time" @@ -33,7 +34,7 @@ type MemSessionStore struct { } // Set value to memory session -func (st *MemSessionStore) Set(key, value interface{}) error { +func (st *MemSessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.value[key] = value @@ -41,7 +42,7 @@ func (st *MemSessionStore) Set(key, value interface{}) error { } // Get value from memory session by key -func (st *MemSessionStore) Get(key interface{}) interface{} { +func (st *MemSessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.value[key]; ok { @@ -51,7 +52,7 @@ func (st *MemSessionStore) Get(key interface{}) interface{} { } // Delete in memory session by key -func (st *MemSessionStore) Delete(key interface{}) error { +func (st *MemSessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.value, key) @@ -59,7 +60,7 @@ func (st *MemSessionStore) Delete(key interface{}) error { } // Flush clear all values in memory session -func (st *MemSessionStore) Flush() error { +func (st *MemSessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.value = make(map[interface{}]interface{}) @@ -67,12 +68,12 @@ func (st *MemSessionStore) Flush() error { } // SessionID get this id of memory session store -func (st *MemSessionStore) SessionID() string { +func (st *MemSessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease Implement method, no used. -func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { +func (st *MemSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { } // MemProvider Implement the provider interface @@ -85,14 +86,14 @@ type MemProvider struct { } // SessionInit init memory session -func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { +func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { pder.maxlifetime = maxlifetime pder.savePath = savePath return nil } // SessionRead get memory session store by sid -func (pder *MemProvider) SessionRead(sid string) (Store, error) { +func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[sid]; ok { go pder.SessionUpdate(sid) @@ -109,7 +110,7 @@ func (pder *MemProvider) SessionRead(sid string) (Store, error) { } // SessionExist check session store exist in memory session by sid -func (pder *MemProvider) SessionExist(sid string) (bool, error) { +func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) { pder.lock.RLock() defer pder.lock.RUnlock() if _, ok := pder.sessions[sid]; ok { @@ -119,7 +120,7 @@ func (pder *MemProvider) SessionExist(sid string) (bool, error) { } // SessionRegenerate generate new sid for session store in memory session -func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[oldsid]; ok { go pder.SessionUpdate(oldsid) @@ -141,7 +142,7 @@ func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { } // SessionDestroy delete session store in memory session by id -func (pder *MemProvider) SessionDestroy(sid string) error { +func (pder *MemProvider) SessionDestroy(ctx context.Context, sid string) error { pder.lock.Lock() defer pder.lock.Unlock() if element, ok := pder.sessions[sid]; ok { @@ -153,7 +154,7 @@ func (pder *MemProvider) SessionDestroy(sid string) error { } // SessionGC clean expired session stores in memory session -func (pder *MemProvider) SessionGC() { +func (pder *MemProvider) SessionGC(context.Context) { pder.lock.RLock() for { element := pder.list.Back() @@ -175,7 +176,7 @@ func (pder *MemProvider) SessionGC() { } // SessionAll get count number of memory session -func (pder *MemProvider) SessionAll() int { +func (pder *MemProvider) SessionAll(context.Context) int { return pder.list.Len() } diff --git a/pkg/infrastructure/session/sess_mem_test.go b/pkg/infrastructure/session/sess_mem_test.go index 2e8934b8..e6d35476 100644 --- a/pkg/infrastructure/session/sess_mem_test.go +++ b/pkg/infrastructure/session/sess_mem_test.go @@ -36,12 +36,12 @@ func TestMem(t *testing.T) { if err != nil { t.Fatal("set error,", err) } - defer sess.SessionRelease(w) - err = sess.Set("username", "astaxie") + defer sess.SessionRelease(nil, w) + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set error,", err) } - if username := sess.Get("username"); username != "astaxie" { + if username := sess.Get(nil, "username"); username != "astaxie" { t.Fatal("get username error") } if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { diff --git a/pkg/infrastructure/session/session.go b/pkg/infrastructure/session/session.go index 92e35de4..bb7e5bd6 100644 --- a/pkg/infrastructure/session/session.go +++ b/pkg/infrastructure/session/session.go @@ -28,6 +28,7 @@ package session import ( + "context" "crypto/rand" "encoding/hex" "errors" @@ -43,24 +44,24 @@ import ( // Store contains all data for one session process with specific id. type Store interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data - Flush() error //delete all data + Set(ctx context.Context, key, value interface{}) error //set session value + Get(ctx context.Context, key interface{}) interface{} //get session value + Delete(ctx context.Context, key interface{}) error //delete session value + SessionID(ctx context.Context) string //back current sessionID + SessionRelease(ctx context.Context, w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush(ctx context.Context) error //delete all data } // Provider contains global session methods and saved SessionStores. // it can operate a SessionStore by its id. type Provider interface { - SessionInit(gclifetime int64, config string) error - SessionRead(sid string) (Store, error) - SessionExist(sid string) (bool, error) - SessionRegenerate(oldsid, sid string) (Store, error) - SessionDestroy(sid string) error - SessionAll() int //get all active session - SessionGC() + SessionInit(ctx context.Context, gclifetime int64, config string) error + SessionRead(ctx context.Context, sid string) (Store, error) + SessionExist(ctx context.Context, sid string) (bool, error) + SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) + SessionDestroy(ctx context.Context, sid string) error + SessionAll(ctx context.Context) int //get all active session + SessionGC(ctx context.Context) } var provides = make(map[string]Provider) @@ -148,7 +149,7 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { } } - err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) + err := provider.SessionInit(nil, cf.Maxlifetime, cf.ProviderConfig) if err != nil { return nil, err } @@ -212,12 +213,12 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se } if sid != "" { - exists, err := manager.provider.SessionExist(sid) + exists, err := manager.provider.SessionExist(nil, sid) if err != nil { return nil, err } if exists { - return manager.provider.SessionRead(sid) + return manager.provider.SessionRead(nil, sid) } } @@ -227,7 +228,7 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se return nil, errs } - session, err = manager.provider.SessionRead(sid) + session, err = manager.provider.SessionRead(nil, sid) if err != nil { return nil, err } @@ -269,7 +270,7 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { } sid, _ := url.QueryUnescape(cookie.Value) - manager.provider.SessionDestroy(sid) + manager.provider.SessionDestroy(nil, sid) if manager.config.EnableSetCookie { expiration := time.Now() cookie = &http.Cookie{Name: manager.config.CookieName, @@ -285,14 +286,14 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { // GetSessionStore Get SessionStore by its id. func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { - sessions, err = manager.provider.SessionRead(sid) + sessions, err = manager.provider.SessionRead(nil, sid) return } // GC Start session gc process. // it can do gc in times after gc lifetime. func (manager *Manager) GC() { - manager.provider.SessionGC() + manager.provider.SessionGC(nil) time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) } @@ -305,7 +306,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque cookie, err := r.Cookie(manager.config.CookieName) if err != nil || cookie.Value == "" { //delete old cookie - session, _ = manager.provider.SessionRead(sid) + session, _ = manager.provider.SessionRead(nil, sid) cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", @@ -315,7 +316,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque } } else { oldsid, _ := url.QueryUnescape(cookie.Value) - session, _ = manager.provider.SessionRegenerate(oldsid, sid) + session, _ = manager.provider.SessionRegenerate(nil, oldsid, sid) cookie.Value = url.QueryEscape(sid) cookie.HttpOnly = true cookie.Path = "/" @@ -339,7 +340,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque // GetActiveSession Get all active sessions count number. func (manager *Manager) GetActiveSession() int { - return manager.provider.SessionAll() + return manager.provider.SessionAll(nil) } // SetSecure Set cookie with https. diff --git a/pkg/infrastructure/session/ssdb/sess_ssdb.go b/pkg/infrastructure/session/ssdb/sess_ssdb.go index 77d0c5c2..6e4f341e 100644 --- a/pkg/infrastructure/session/ssdb/sess_ssdb.go +++ b/pkg/infrastructure/session/ssdb/sess_ssdb.go @@ -1,6 +1,7 @@ package ssdb import ( + "context" "errors" "net/http" "strconv" @@ -31,7 +32,7 @@ func (p *Provider) connectInit() error { } // SessionInit init the ssdb with the config -func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { +func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error { p.maxLifetime = maxLifetime address := strings.Split(savePath, ":") p.host = address[0] @@ -44,7 +45,7 @@ func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { } // SessionRead return a ssdb client session Store -func (p *Provider) SessionRead(sid string) (session.Store, error) { +func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { if p.client == nil { if err := p.connectInit(); err != nil { return nil, err @@ -68,7 +69,7 @@ func (p *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist judged whether sid is exist in session -func (p *Provider) SessionExist(sid string) (bool, error) { +func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { if p.client == nil { if err := p.connectInit(); err != nil { return false, err @@ -85,7 +86,7 @@ func (p *Provider) SessionExist(sid string) (bool, error) { } // SessionRegenerate regenerate session with new sid and delete oldsid -func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (p *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { //conn.Do("setx", key, v, ttl) if p.client == nil { if err := p.connectInit(); err != nil { @@ -118,7 +119,7 @@ func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy destroy the sid -func (p *Provider) SessionDestroy(sid string) error { +func (p *Provider) SessionDestroy(ctx context.Context, sid string) error { if p.client == nil { if err := p.connectInit(); err != nil { return err @@ -129,11 +130,11 @@ func (p *Provider) SessionDestroy(sid string) error { } // SessionGC not implemented -func (p *Provider) SessionGC() { +func (p *Provider) SessionGC(context.Context) { } // SessionAll not implemented -func (p *Provider) SessionAll() int { +func (p *Provider) SessionAll(context.Context) int { return 0 } @@ -147,7 +148,7 @@ type SessionStore struct { } // Set the key and value -func (s *SessionStore) Set(key, value interface{}) error { +func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error { s.lock.Lock() defer s.lock.Unlock() s.values[key] = value @@ -155,7 +156,7 @@ func (s *SessionStore) Set(key, value interface{}) error { } // Get return the value by the key -func (s *SessionStore) Get(key interface{}) interface{} { +func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} { s.lock.Lock() defer s.lock.Unlock() if value, ok := s.values[key]; ok { @@ -165,7 +166,7 @@ func (s *SessionStore) Get(key interface{}) interface{} { } // Delete the key in session store -func (s *SessionStore) Delete(key interface{}) error { +func (s *SessionStore) Delete(ctx context.Context, key interface{}) error { s.lock.Lock() defer s.lock.Unlock() delete(s.values, key) @@ -173,7 +174,7 @@ func (s *SessionStore) Delete(key interface{}) error { } // Flush delete all keys and values -func (s *SessionStore) Flush() error { +func (s *SessionStore) Flush(context.Context) error { s.lock.Lock() defer s.lock.Unlock() s.values = make(map[interface{}]interface{}) @@ -181,12 +182,12 @@ func (s *SessionStore) Flush() error { } // SessionID return the sessionID -func (s *SessionStore) SessionID() string { +func (s *SessionStore) SessionID(context.Context) string { return s.sid } // SessionRelease Store the keyvalues into ssdb -func (s *SessionStore) SessionRelease(w http.ResponseWriter) { +func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { b, err := session.EncodeGob(s.values) if err != nil { return diff --git a/pkg/server/web/context/input.go b/pkg/server/web/context/input.go index b8272f64..a6fec774 100644 --- a/pkg/server/web/context/input.go +++ b/pkg/server/web/context/input.go @@ -361,7 +361,7 @@ func (input *BeegoInput) Cookie(key string) string { // Session returns current session item value by a given key. // if non-existed, return nil. func (input *BeegoInput) Session(key interface{}) interface{} { - return input.CruSession.Get(key) + return input.CruSession.Get(nil, key) } // CopyBody returns the raw request body data as bytes. diff --git a/pkg/server/web/context/output.go b/pkg/server/web/context/output.go index 0a530244..a6e83681 100644 --- a/pkg/server/web/context/output.go +++ b/pkg/server/web/context/output.go @@ -404,5 +404,5 @@ func stringsToJSON(str string) string { // Session sets session item value with given key. func (output *BeegoOutput) Session(name interface{}, value interface{}) { - output.Context.Input.CruSession.Set(name, value) + output.Context.Input.CruSession.Set(nil, name, value) } diff --git a/pkg/server/web/controller.go b/pkg/server/web/controller.go index 6b71d617..2081e647 100644 --- a/pkg/server/web/controller.go +++ b/pkg/server/web/controller.go @@ -622,7 +622,7 @@ func (c *Controller) SetSession(name interface{}, value interface{}) { if c.CruSession == nil { c.StartSession() } - c.CruSession.Set(name, value) + c.CruSession.Set(nil, name, value) } // GetSession gets value from session. @@ -630,7 +630,7 @@ func (c *Controller) GetSession(name interface{}) interface{} { if c.CruSession == nil { c.StartSession() } - return c.CruSession.Get(name) + return c.CruSession.Get(nil, name) } // DelSession removes value from session. @@ -638,14 +638,14 @@ func (c *Controller) DelSession(name interface{}) { if c.CruSession == nil { c.StartSession() } - c.CruSession.Delete(name) + c.CruSession.Delete(nil, name) } // SessionRegenerateID regenerates session id for this session. // the session data have no changes. func (c *Controller) SessionRegenerateID() { if c.CruSession != nil { - c.CruSession.SessionRelease(c.Ctx.ResponseWriter) + c.CruSession.SessionRelease(nil, c.Ctx.ResponseWriter) } c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) c.Ctx.Input.CruSession = c.CruSession @@ -653,7 +653,7 @@ func (c *Controller) SessionRegenerateID() { // DestroySession cleans session data and session cookie. func (c *Controller) DestroySession() { - c.Ctx.Input.CruSession.Flush() + c.Ctx.Input.CruSession.Flush(nil) c.Ctx.Input.CruSession = nil GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) } diff --git a/pkg/server/web/router.go b/pkg/server/web/router.go index 9b70753e..c3eddd29 100644 --- a/pkg/server/web/router.go +++ b/pkg/server/web/router.go @@ -721,7 +721,7 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } defer func() { if ctx.Input.CruSession != nil { - ctx.Input.CruSession.SessionRelease(rw) + ctx.Input.CruSession.SessionRelease(nil, rw) } }() }