diff --git a/server/web/session/couchbase/sess_couchbase.go b/server/web/session/couchbase/sess_couchbase.go index ddd46401..7f15956a 100644 --- a/server/web/session/couchbase/sess_couchbase.go +++ b/server/web/session/couchbase/sess_couchbase.go @@ -34,6 +34,7 @@ package couchbase import ( "context" + "encoding/json" "net/http" "strings" "sync" @@ -57,9 +58,9 @@ type SessionStore struct { // Provider couchabse provided type Provider struct { maxlifetime int64 - savePath string - pool string - bucket string + SavePath string `json:"save_path"` + Pool string `json:"pool"` + Bucket string `json:"bucket"` b *couchbase.Bucket } @@ -115,17 +116,17 @@ func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite } func (cp *Provider) getBucket() *couchbase.Bucket { - c, err := couchbase.Connect(cp.savePath) + c, err := couchbase.Connect(cp.SavePath) if err != nil { return nil } - pool, err := c.GetPool(cp.pool) + pool, err := c.GetPool(cp.Pool) if err != nil { return nil } - bucket, err := pool.GetBucket(cp.bucket) + bucket, err := pool.GetBucket(cp.Bucket) if err != nil { return nil } @@ -135,18 +136,31 @@ func (cp *Provider) getBucket() *couchbase.Bucket { // SessionInit init couchbase session // savepath like couchbase server REST/JSON URL -// e.g. http://host:port/, Pool, Bucket -func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { +// For v1.x e.g. http://host:port/, Pool, Bucket +// For v2.x, you should pass json string. +// e.g. { "save_path": "http://host:port/", "pool": "mypool", "bucket": "mybucket"} +func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfg string) error { cp.maxlifetime = maxlifetime + cfg = strings.TrimSpace(cfg) + // we think this is v2.0, using json to init the session + if strings.HasPrefix(cfg, "{") { + return json.Unmarshal([]byte(cfg), cp) + } else { + return cp.initOldStyle(cfg) + } +} + +// initOldStyle keep compatible with v1.x +func (cp *Provider) initOldStyle(savePath string) error { configs := strings.Split(savePath, ",") if len(configs) > 0 { - cp.savePath = configs[0] + cp.SavePath = configs[0] } if len(configs) > 1 { - cp.pool = configs[1] + cp.Pool = configs[1] } if len(configs) > 2 { - cp.bucket = configs[2] + cp.Bucket = configs[2] } return nil @@ -225,7 +239,7 @@ func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) ( return cs, nil } -// SessionDestroy Remove bucket in this couchbase +// SessionDestroy Remove Bucket in this couchbase func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error { cp.b = cp.getBucket() defer cp.b.Close() diff --git a/server/web/session/couchbase/sess_couchbase_test.go b/server/web/session/couchbase/sess_couchbase_test.go new file mode 100644 index 00000000..5959f9c3 --- /dev/null +++ b/server/web/session/couchbase/sess_couchbase_test.go @@ -0,0 +1,43 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package couchbase + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + // using old style + savePath := `http://host:port/,Pool,Bucket` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "http://host:port/", cp.SavePath) + assert.Equal(t, "Pool", cp.Pool) + assert.Equal(t, "Bucket", cp.Bucket) + assert.Equal(t, int64(12), cp.maxlifetime) + + savePath = ` +{ "save_path": "my save path", "pool": "mypool", "bucket": "mybucket"} +` + cp = &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, "mypool", cp.Pool) + assert.Equal(t, "mybucket", cp.Bucket) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/ledis/ledis_session.go b/server/web/session/ledis/ledis_session.go index a920ff7c..5b930fcd 100644 --- a/server/web/session/ledis/ledis_session.go +++ b/server/web/session/ledis/ledis_session.go @@ -3,6 +3,7 @@ package ledis import ( "context" + "encoding/json" "net/http" "strconv" "strings" @@ -79,35 +80,51 @@ func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite // Provider ledis session provider type Provider struct { maxlifetime int64 - savePath string - db int + SavePath string `json:"save_path"` + Db int `json:"db"` } // SessionInit init ledis session // savepath like ledis server saveDataPath,pool size -// e.g. 127.0.0.1:6379,100,astaxie -func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { +// v1.x e.g. 127.0.0.1:6379,100 +// v2.x you should pass a json string +// e.g. { "save_path": "my save path", "db": 100} +func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { var err error lp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) == 1 { - lp.savePath = configs[0] - } else if len(configs) == 2 { - lp.savePath = configs[0] - lp.db, err = strconv.Atoi(configs[1]) - if err != nil { - return err - } + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err = json.Unmarshal([]byte(cfgStr), lp) + } else { + err = lp.initOldStyle(cfgStr) } + + if err != nil { + return err + } + cfg := new(config.Config) - cfg.DataDir = lp.savePath + cfg.DataDir = lp.SavePath var ledisInstance *ledis.Ledis ledisInstance, err = ledis.Open(cfg) if err != nil { return err } - c, err = ledisInstance.Select(lp.db) + c, err = ledisInstance.Select(lp.Db) + return err +} + +func (lp *Provider) initOldStyle(cfgStr string) error { + var err error + configs := strings.Split(cfgStr, ",") + if len(configs) == 1 { + lp.SavePath = configs[0] + } else if len(configs) == 2 { + lp.SavePath = configs[0] + lp.Db, err = strconv.Atoi(configs[1]) + } return err } diff --git a/server/web/session/ledis/ledis_session_test.go b/server/web/session/ledis/ledis_session_test.go new file mode 100644 index 00000000..1cfb3ed1 --- /dev/null +++ b/server/web/session/ledis/ledis_session_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ledis + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + // using old style + savePath := `http://host:port/,100` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "http://host:port/", cp.SavePath) + assert.Equal(t, 100, cp.Db) + assert.Equal(t, int64(12), cp.maxlifetime) + + savePath = ` +{ "save_path": "my save path", "db": 100} +` + cp = &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 100, cp.Db) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/redis/sess_redis.go b/server/web/session/redis/sess_redis.go index 6ee28e2f..c6e3bcbb 100644 --- a/server/web/session/redis/sess_redis.go +++ b/server/web/session/redis/sess_redis.go @@ -34,6 +34,7 @@ package redis import ( "context" + "encoding/json" "net/http" "strconv" "strings" @@ -110,48 +111,89 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite // Provider redis session provider type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - idleTimeout time.Duration - idleCheckFrequency time.Duration - maxRetries int - poollist *redis.Client + maxlifetime int64 + SavePath string `json:"save_path"` + Poolsize int `json:"poolsize"` + Password string `json:"password"` + DbNum int `json:"db_num"` + + idleTimeout time.Duration + IdleTimeoutStr string `json:"idle_timeout"` + + idleCheckFrequency time.Duration + IdleCheckFrequencyStr string `json:"idle_check_frequency"` + MaxRetries int `json:"max_retries"` + poollist *redis.Client } // SessionInit init redis session // savepath like redis server addr,pool size,password,dbnum,IdleTimeout second -// e.g. 127.0.0.1:6379,100,astaxie,0,30 -func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { +// v1.x e.g. 127.0.0.1:6379,100,astaxie,0,30 +// v2.0 you should pass json string +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { rp.maxlifetime = maxlifetime + + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err := json.Unmarshal([]byte(cfgStr), rp) + if err != nil { + return err + } + rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr) + if err != nil { + return err + } + + rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr) + if err != nil { + return err + } + + } else { + rp.initOldStyle(cfgStr) + } + + rp.poollist = redis.NewClient(&redis.Options{ + Addr: rp.SavePath, + Password: rp.Password, + PoolSize: rp.Poolsize, + DB: rp.DbNum, + IdleTimeout: rp.idleTimeout, + IdleCheckFrequency: rp.idleCheckFrequency, + MaxRetries: rp.MaxRetries, + }) + + return rp.poollist.Ping().Err() +} + +func (rp *Provider) initOldStyle(savePath string) { configs := strings.Split(savePath, ",") if len(configs) > 0 { - rp.savePath = configs[0] + rp.SavePath = configs[0] } if len(configs) > 1 { poolsize, err := strconv.Atoi(configs[1]) if err != nil || poolsize < 0 { - rp.poolsize = MaxPoolSize + rp.Poolsize = MaxPoolSize } else { - rp.poolsize = poolsize + rp.Poolsize = poolsize } } else { - rp.poolsize = MaxPoolSize + rp.Poolsize = MaxPoolSize } if len(configs) > 2 { - rp.password = configs[2] + rp.Password = configs[2] } if len(configs) > 3 { dbnum, err := strconv.Atoi(configs[3]) if err != nil || dbnum < 0 { - rp.dbNum = 0 + rp.DbNum = 0 } else { - rp.dbNum = dbnum + rp.DbNum = dbnum } } else { - rp.dbNum = 0 + rp.DbNum = 0 } if len(configs) > 4 { timeout, err := strconv.Atoi(configs[4]) @@ -168,21 +210,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath if len(configs) > 6 { retries, err := strconv.Atoi(configs[6]) if err == nil && retries > 0 { - rp.maxRetries = retries + rp.MaxRetries = retries } } - - rp.poollist = redis.NewClient(&redis.Options{ - Addr: rp.savePath, - Password: rp.password, - PoolSize: rp.poolsize, - DB: rp.dbNum, - IdleTimeout: rp.idleTimeout, - IdleCheckFrequency: rp.idleCheckFrequency, - MaxRetries: rp.maxRetries, - }) - - return rp.poollist.Ping().Err() } // SessionRead read redis session by sid diff --git a/server/web/session/redis/sess_redis_test.go b/server/web/session/redis/sess_redis_test.go index 19c8c025..64dbc9f9 100644 --- a/server/web/session/redis/sess_redis_test.go +++ b/server/web/session/redis/sess_redis_test.go @@ -1,11 +1,15 @@ package redis import ( + "context" "fmt" "net/http" "net/http/httptest" "os" "testing" + "time" + + "github.com/stretchr/testify/assert" "github.com/astaxie/beego/server/web/session" ) @@ -94,3 +98,15 @@ func TestRedis(t *testing.T) { sess.SessionRelease(nil, w) } + +func TestProvider_SessionInit(t *testing.T) { + + savePath := ` +{ "save_path": "my save path", "idle_timeout": "3s"} +` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 3*time.Second, cp.idleTimeout) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/redis_cluster/redis_cluster.go b/server/web/session/redis_cluster/redis_cluster.go index 17653d56..d2971e71 100644 --- a/server/web/session/redis_cluster/redis_cluster.go +++ b/server/web/session/redis_cluster/redis_cluster.go @@ -34,14 +34,16 @@ package redis_cluster import ( "context" + "encoding/json" "net/http" "strconv" "strings" "sync" "time" - "github.com/astaxie/beego/server/web/session" rediss "github.com/go-redis/redis/v7" + + "github.com/astaxie/beego/server/web/session" ) var redispder = &Provider{} @@ -109,48 +111,86 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite // Provider redis_cluster session provider type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - idleTimeout time.Duration - idleCheckFrequency time.Duration - maxRetries int - poollist *rediss.ClusterClient + maxlifetime int64 + SavePath string `json:"save_path"` + Poolsize int `json:"poolsize"` + Password string `json:"password"` + DbNum int `json:"db_num"` + + idleTimeout time.Duration + IdleTimeoutStr string `json:"idle_timeout"` + + idleCheckFrequency time.Duration + IdleCheckFrequencyStr string `json:"idle_check_frequency"` + MaxRetries int `json:"max_retries"` + poollist *rediss.ClusterClient } // SessionInit init redis_cluster session -// savepath like redis server addr,pool size,password,dbnum +// cfgStr like redis server addr,pool size,password,dbnum // e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 -func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { rp.maxlifetime = maxlifetime + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err := json.Unmarshal([]byte(cfgStr), rp) + if err != nil { + return err + } + rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr) + if err != nil { + return err + } + + rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr) + if err != nil { + return err + } + + } else { + rp.initOldStyle(cfgStr) + } + + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ + Addrs: strings.Split(rp.SavePath, ";"), + Password: rp.Password, + PoolSize: rp.Poolsize, + IdleTimeout: rp.idleTimeout, + IdleCheckFrequency: rp.idleCheckFrequency, + MaxRetries: rp.MaxRetries, + }) + return rp.poollist.Ping().Err() +} + +// for v1.x +func (rp *Provider) initOldStyle(savePath string) { configs := strings.Split(savePath, ",") if len(configs) > 0 { - rp.savePath = configs[0] + rp.SavePath = configs[0] } if len(configs) > 1 { poolsize, err := strconv.Atoi(configs[1]) if err != nil || poolsize < 0 { - rp.poolsize = MaxPoolSize + rp.Poolsize = MaxPoolSize } else { - rp.poolsize = poolsize + rp.Poolsize = poolsize } } else { - rp.poolsize = MaxPoolSize + rp.Poolsize = MaxPoolSize } if len(configs) > 2 { - rp.password = configs[2] + rp.Password = configs[2] } if len(configs) > 3 { dbnum, err := strconv.Atoi(configs[3]) if err != nil || dbnum < 0 { - rp.dbNum = 0 + rp.DbNum = 0 } else { - rp.dbNum = dbnum + rp.DbNum = dbnum } } else { - rp.dbNum = 0 + rp.DbNum = 0 } if len(configs) > 4 { timeout, err := strconv.Atoi(configs[4]) @@ -167,19 +207,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath if len(configs) > 6 { retries, err := strconv.Atoi(configs[6]) if err == nil && retries > 0 { - rp.maxRetries = retries + rp.MaxRetries = retries } } - - rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ - Addrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - IdleTimeout: rp.idleTimeout, - IdleCheckFrequency: rp.idleCheckFrequency, - MaxRetries: rp.maxRetries, - }) - return rp.poollist.Ping().Err() } // SessionRead read redis_cluster session by sid diff --git a/server/web/session/redis_cluster/redis_cluster_test.go b/server/web/session/redis_cluster/redis_cluster_test.go new file mode 100644 index 00000000..0192cd87 --- /dev/null +++ b/server/web/session/redis_cluster/redis_cluster_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis_cluster + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + + savePath := ` +{ "save_path": "my save path", "idle_timeout": "3s"} +` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 3*time.Second, cp.idleTimeout) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel.go b/server/web/session/redis_sentinel/sess_redis_sentinel.go index d68b8767..89d73b86 100644 --- a/server/web/session/redis_sentinel/sess_redis_sentinel.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel.go @@ -34,6 +34,7 @@ package redis_sentinel import ( "context" + "encoding/json" "net/http" "strconv" "strings" @@ -110,58 +111,99 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite // Provider redis_sentinel session provider type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - idleTimeout time.Duration - idleCheckFrequency time.Duration - maxRetries int - poollist *redis.Client - masterName string + maxlifetime int64 + SavePath string `json:"save_path"` + Poolsize int `json:"poolsize"` + Password string `json:"password"` + DbNum int `json:"db_num"` + + idleTimeout time.Duration + IdleTimeoutStr string `json:"idle_timeout"` + + idleCheckFrequency time.Duration + IdleCheckFrequencyStr string `json:"idle_check_frequency"` + MaxRetries int `json:"max_retries"` + poollist *redis.Client + MasterName string `json:"master_name"` } // SessionInit init redis_sentinel session -// savepath like redis sentinel addr,pool size,password,dbnum,masterName +// cfgStr like redis sentinel addr,pool size,password,dbnum,masterName // e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster -func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { rp.maxlifetime = maxlifetime + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err := json.Unmarshal([]byte(cfgStr), rp) + if err != nil { + return err + } + rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr) + if err != nil { + return err + } + + rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr) + if err != nil { + return err + } + + } else { + rp.initOldStyle(cfgStr) + } + + rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ + SentinelAddrs: strings.Split(rp.SavePath, ";"), + Password: rp.Password, + PoolSize: rp.Poolsize, + DB: rp.DbNum, + MasterName: rp.MasterName, + IdleTimeout: rp.idleTimeout, + IdleCheckFrequency: rp.idleCheckFrequency, + MaxRetries: rp.MaxRetries, + }) + + return rp.poollist.Ping().Err() +} + +// for v1.x +func (rp *Provider) initOldStyle(savePath string) { configs := strings.Split(savePath, ",") if len(configs) > 0 { - rp.savePath = configs[0] + rp.SavePath = configs[0] } if len(configs) > 1 { poolsize, err := strconv.Atoi(configs[1]) if err != nil || poolsize < 0 { - rp.poolsize = DefaultPoolSize + rp.Poolsize = DefaultPoolSize } else { - rp.poolsize = poolsize + rp.Poolsize = poolsize } } else { - rp.poolsize = DefaultPoolSize + rp.Poolsize = DefaultPoolSize } if len(configs) > 2 { - rp.password = configs[2] + rp.Password = configs[2] } if len(configs) > 3 { dbnum, err := strconv.Atoi(configs[3]) if err != nil || dbnum < 0 { - rp.dbNum = 0 + rp.DbNum = 0 } else { - rp.dbNum = dbnum + rp.DbNum = dbnum } } else { - rp.dbNum = 0 + rp.DbNum = 0 } if len(configs) > 4 { if configs[4] != "" { - rp.masterName = configs[4] + rp.MasterName = configs[4] } else { - rp.masterName = "mymaster" + rp.MasterName = "mymaster" } } else { - rp.masterName = "mymaster" + rp.MasterName = "mymaster" } if len(configs) > 5 { timeout, err := strconv.Atoi(configs[4]) @@ -178,22 +220,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath if len(configs) > 7 { retries, err := strconv.Atoi(configs[6]) if err == nil && retries > 0 { - rp.maxRetries = retries + rp.MaxRetries = retries } } - - rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ - SentinelAddrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - DB: rp.dbNum, - MasterName: rp.masterName, - IdleTimeout: rp.idleTimeout, - IdleCheckFrequency: rp.idleCheckFrequency, - MaxRetries: rp.maxRetries, - }) - - return rp.poollist.Ping().Err() } // SessionRead read redis_sentinel session by sid diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go index e4822d11..f052a14a 100644 --- a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go @@ -1,9 +1,13 @@ package redis_sentinel import ( + "context" "net/http" "net/http/httptest" "testing" + "time" + + "github.com/stretchr/testify/assert" "github.com/astaxie/beego/server/web/session" ) @@ -23,7 +27,7 @@ func TestRedisSentinel(t *testing.T) { t.Log(e) return } - //todo test if e==nil + // todo test if e==nil go globalSessions.GC() r, _ := http.NewRequest("GET", "/", nil) @@ -88,3 +92,15 @@ func TestRedisSentinel(t *testing.T) { sess.SessionRelease(nil, w) } + +func TestProvider_SessionInit(t *testing.T) { + + savePath := ` +{ "save_path": "my save path", "idle_timeout": "3s"} +` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 3*time.Second, cp.idleTimeout) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/ssdb/sess_ssdb.go b/server/web/session/ssdb/sess_ssdb.go index 9d1230d0..0adc41bd 100644 --- a/server/web/session/ssdb/sess_ssdb.go +++ b/server/web/session/ssdb/sess_ssdb.go @@ -2,6 +2,7 @@ package ssdb import ( "context" + "encoding/json" "errors" "net/http" "strconv" @@ -18,33 +19,48 @@ var ssdbProvider = &Provider{} // Provider holds ssdb client and configs type Provider struct { client *ssdb.Client - host string - port int + Host string `json:"host"` + Port int `json:"port"` maxLifetime int64 } func (p *Provider) connectInit() error { var err error - if p.host == "" || p.port == 0 { + if p.Host == "" || p.Port == 0 { return errors.New("SessionInit First") } - p.client, err = ssdb.Connect(p.host, p.port) + p.client, err = ssdb.Connect(p.Host, p.Port) return err } // SessionInit init the ssdb with the config -func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error { +func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, cfg string) error { p.maxLifetime = maxLifetime - address := strings.Split(savePath, ":") - p.host = address[0] + cfg = strings.TrimSpace(cfg) var err error - if p.port, err = strconv.Atoi(address[1]); err != nil { + // we think this is v2.0, using json to init the session + if strings.HasPrefix(cfg, "{") { + err = json.Unmarshal([]byte(cfg), p) + } else { + err = p.initOldStyle(cfg) + } + if err != nil { return err } return p.connectInit() } +// for v1.x +func (p *Provider) initOldStyle(savePath string) error { + address := strings.Split(savePath, ":") + p.Host = address[0] + + var err error + p.Port, err = strconv.Atoi(address[1]) + return err +} + // SessionRead return a ssdb client session Store func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { if p.client == nil { diff --git a/server/web/session/ssdb/sess_ssdb_test.go b/server/web/session/ssdb/sess_ssdb_test.go new file mode 100644 index 00000000..3de5da0a --- /dev/null +++ b/server/web/session/ssdb/sess_ssdb_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssdb + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + // using old style + savePath := `localhost:8080` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "localhost", cp.Host) + assert.Equal(t, 8080, cp.Port) + assert.Equal(t, int64(12), cp.maxLifetime) + + savePath = ` +{ "host": "localhost", "port": 8080} +` + cp = &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "localhost", cp.Host) + assert.Equal(t, 8080, cp.Port) + assert.Equal(t, int64(12), cp.maxLifetime) +}