support using json string to init session

This commit is contained in:
Ming Deng 2020-10-21 22:12:25 +08:00
parent 03ba495b7f
commit 05f4e0c146
12 changed files with 464 additions and 136 deletions

View File

@ -34,6 +34,7 @@ package couchbase
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@ -57,9 +58,9 @@ type SessionStore struct {
// Provider couchabse provided // Provider couchabse provided
type Provider struct { type Provider struct {
maxlifetime int64 maxlifetime int64
savePath string SavePath string `json:"save_path"`
pool string Pool string `json:"pool"`
bucket string Bucket string `json:"bucket"`
b *couchbase.Bucket b *couchbase.Bucket
} }
@ -115,17 +116,17 @@ func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
} }
func (cp *Provider) getBucket() *couchbase.Bucket { func (cp *Provider) getBucket() *couchbase.Bucket {
c, err := couchbase.Connect(cp.savePath) c, err := couchbase.Connect(cp.SavePath)
if err != nil { if err != nil {
return nil return nil
} }
pool, err := c.GetPool(cp.pool) pool, err := c.GetPool(cp.Pool)
if err != nil { if err != nil {
return nil return nil
} }
bucket, err := pool.GetBucket(cp.bucket) bucket, err := pool.GetBucket(cp.Bucket)
if err != nil { if err != nil {
return nil return nil
} }
@ -135,18 +136,31 @@ func (cp *Provider) getBucket() *couchbase.Bucket {
// SessionInit init couchbase session // SessionInit init couchbase session
// savepath like couchbase server REST/JSON URL // savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket // For v1.x e.g. http://host:port/, Pool, Bucket
func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { // For v2.x, you should pass json string.
// e.g. { "save_path": "http://host:port/", "pool": "mypool", "bucket": "mybucket"}
func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfg string) error {
cp.maxlifetime = maxlifetime cp.maxlifetime = maxlifetime
cfg = strings.TrimSpace(cfg)
// we think this is v2.0, using json to init the session
if strings.HasPrefix(cfg, "{") {
return json.Unmarshal([]byte(cfg), cp)
} else {
return cp.initOldStyle(cfg)
}
}
// initOldStyle keep compatible with v1.x
func (cp *Provider) initOldStyle(savePath string) error {
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
if len(configs) > 0 { if len(configs) > 0 {
cp.savePath = configs[0] cp.SavePath = configs[0]
} }
if len(configs) > 1 { if len(configs) > 1 {
cp.pool = configs[1] cp.Pool = configs[1]
} }
if len(configs) > 2 { if len(configs) > 2 {
cp.bucket = configs[2] cp.Bucket = configs[2]
} }
return nil return nil
@ -225,7 +239,7 @@ func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (
return cs, nil return cs, nil
} }
// SessionDestroy Remove bucket in this couchbase // SessionDestroy Remove Bucket in this couchbase
func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error { func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error {
cp.b = cp.getBucket() cp.b = cp.getBucket()
defer cp.b.Close() defer cp.b.Close()

View File

@ -0,0 +1,43 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package couchbase
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestProvider_SessionInit(t *testing.T) {
// using old style
savePath := `http://host:port/,Pool,Bucket`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "http://host:port/", cp.SavePath)
assert.Equal(t, "Pool", cp.Pool)
assert.Equal(t, "Bucket", cp.Bucket)
assert.Equal(t, int64(12), cp.maxlifetime)
savePath = `
{ "save_path": "my save path", "pool": "mypool", "bucket": "mybucket"}
`
cp = &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "my save path", cp.SavePath)
assert.Equal(t, "mypool", cp.Pool)
assert.Equal(t, "mybucket", cp.Bucket)
assert.Equal(t, int64(12), cp.maxlifetime)
}

View File

@ -3,6 +3,7 @@ package ledis
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -79,35 +80,51 @@ func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
// Provider ledis session provider // Provider ledis session provider
type Provider struct { type Provider struct {
maxlifetime int64 maxlifetime int64
savePath string SavePath string `json:"save_path"`
db int Db int `json:"db"`
} }
// SessionInit init ledis session // SessionInit init ledis session
// savepath like ledis server saveDataPath,pool size // savepath like ledis server saveDataPath,pool size
// e.g. 127.0.0.1:6379,100,astaxie // v1.x e.g. 127.0.0.1:6379,100
func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { // v2.x you should pass a json string
// e.g. { "save_path": "my save path", "db": 100}
func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
var err error var err error
lp.maxlifetime = maxlifetime lp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",") cfgStr = strings.TrimSpace(cfgStr)
if len(configs) == 1 { // we think cfgStr is v2.0, using json to init the session
lp.savePath = configs[0] if strings.HasPrefix(cfgStr, "{") {
} else if len(configs) == 2 { err = json.Unmarshal([]byte(cfgStr), lp)
lp.savePath = configs[0] } else {
lp.db, err = strconv.Atoi(configs[1]) err = lp.initOldStyle(cfgStr)
if err != nil {
return err
}
} }
if err != nil {
return err
}
cfg := new(config.Config) cfg := new(config.Config)
cfg.DataDir = lp.savePath cfg.DataDir = lp.SavePath
var ledisInstance *ledis.Ledis var ledisInstance *ledis.Ledis
ledisInstance, err = ledis.Open(cfg) ledisInstance, err = ledis.Open(cfg)
if err != nil { if err != nil {
return err return err
} }
c, err = ledisInstance.Select(lp.db) c, err = ledisInstance.Select(lp.Db)
return err
}
func (lp *Provider) initOldStyle(cfgStr string) error {
var err error
configs := strings.Split(cfgStr, ",")
if len(configs) == 1 {
lp.SavePath = configs[0]
} else if len(configs) == 2 {
lp.SavePath = configs[0]
lp.Db, err = strconv.Atoi(configs[1])
}
return err return err
} }

View File

@ -0,0 +1,41 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ledis
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestProvider_SessionInit(t *testing.T) {
// using old style
savePath := `http://host:port/,100`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "http://host:port/", cp.SavePath)
assert.Equal(t, 100, cp.Db)
assert.Equal(t, int64(12), cp.maxlifetime)
savePath = `
{ "save_path": "my save path", "db": 100}
`
cp = &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "my save path", cp.SavePath)
assert.Equal(t, 100, cp.Db)
assert.Equal(t, int64(12), cp.maxlifetime)
}

View File

@ -34,6 +34,7 @@ package redis
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -110,48 +111,89 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
// Provider redis session provider // Provider redis session provider
type Provider struct { type Provider struct {
maxlifetime int64 maxlifetime int64
savePath string SavePath string `json:"save_path"`
poolsize int Poolsize int `json:"poolsize"`
password string Password string `json:"password"`
dbNum int DbNum int `json:"db_num"`
idleTimeout time.Duration
idleCheckFrequency time.Duration idleTimeout time.Duration
maxRetries int IdleTimeoutStr string `json:"idle_timeout"`
poollist *redis.Client
idleCheckFrequency time.Duration
IdleCheckFrequencyStr string `json:"idle_check_frequency"`
MaxRetries int `json:"max_retries"`
poollist *redis.Client
} }
// SessionInit init redis session // SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second // savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
// e.g. 127.0.0.1:6379,100,astaxie,0,30 // v1.x e.g. 127.0.0.1:6379,100,astaxie,0,30
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { // v2.0 you should pass json string
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
rp.maxlifetime = maxlifetime rp.maxlifetime = maxlifetime
cfgStr = strings.TrimSpace(cfgStr)
// we think cfgStr is v2.0, using json to init the session
if strings.HasPrefix(cfgStr, "{") {
err := json.Unmarshal([]byte(cfgStr), rp)
if err != nil {
return err
}
rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr)
if err != nil {
return err
}
rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr)
if err != nil {
return err
}
} else {
rp.initOldStyle(cfgStr)
}
rp.poollist = redis.NewClient(&redis.Options{
Addr: rp.SavePath,
Password: rp.Password,
PoolSize: rp.Poolsize,
DB: rp.DbNum,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.MaxRetries,
})
return rp.poollist.Ping().Err()
}
func (rp *Provider) initOldStyle(savePath string) {
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
if len(configs) > 0 { if len(configs) > 0 {
rp.savePath = configs[0] rp.SavePath = configs[0]
} }
if len(configs) > 1 { if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1]) poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 { if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize rp.Poolsize = MaxPoolSize
} else { } else {
rp.poolsize = poolsize rp.Poolsize = poolsize
} }
} else { } else {
rp.poolsize = MaxPoolSize rp.Poolsize = MaxPoolSize
} }
if len(configs) > 2 { if len(configs) > 2 {
rp.password = configs[2] rp.Password = configs[2]
} }
if len(configs) > 3 { if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3]) dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 { if err != nil || dbnum < 0 {
rp.dbNum = 0 rp.DbNum = 0
} else { } else {
rp.dbNum = dbnum rp.DbNum = dbnum
} }
} else { } else {
rp.dbNum = 0 rp.DbNum = 0
} }
if len(configs) > 4 { if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4]) timeout, err := strconv.Atoi(configs[4])
@ -168,21 +210,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath
if len(configs) > 6 { if len(configs) > 6 {
retries, err := strconv.Atoi(configs[6]) retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 { if err == nil && retries > 0 {
rp.maxRetries = retries rp.MaxRetries = retries
} }
} }
rp.poollist = redis.NewClient(&redis.Options{
Addr: rp.savePath,
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
} }
// SessionRead read redis session by sid // SessionRead read redis session by sid

View File

@ -1,11 +1,15 @@
package redis package redis
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/server/web/session" "github.com/astaxie/beego/server/web/session"
) )
@ -94,3 +98,15 @@ func TestRedis(t *testing.T) {
sess.SessionRelease(nil, w) sess.SessionRelease(nil, w)
} }
func TestProvider_SessionInit(t *testing.T) {
savePath := `
{ "save_path": "my save path", "idle_timeout": "3s"}
`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "my save path", cp.SavePath)
assert.Equal(t, 3*time.Second, cp.idleTimeout)
assert.Equal(t, int64(12), cp.maxlifetime)
}

View File

@ -34,14 +34,16 @@ package redis_cluster
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/astaxie/beego/server/web/session"
rediss "github.com/go-redis/redis/v7" rediss "github.com/go-redis/redis/v7"
"github.com/astaxie/beego/server/web/session"
) )
var redispder = &Provider{} var redispder = &Provider{}
@ -109,48 +111,86 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
// Provider redis_cluster session provider // Provider redis_cluster session provider
type Provider struct { type Provider struct {
maxlifetime int64 maxlifetime int64
savePath string SavePath string `json:"save_path"`
poolsize int Poolsize int `json:"poolsize"`
password string Password string `json:"password"`
dbNum int DbNum int `json:"db_num"`
idleTimeout time.Duration
idleCheckFrequency time.Duration idleTimeout time.Duration
maxRetries int IdleTimeoutStr string `json:"idle_timeout"`
poollist *rediss.ClusterClient
idleCheckFrequency time.Duration
IdleCheckFrequencyStr string `json:"idle_check_frequency"`
MaxRetries int `json:"max_retries"`
poollist *rediss.ClusterClient
} }
// SessionInit init redis_cluster session // SessionInit init redis_cluster session
// savepath like redis server addr,pool size,password,dbnum // cfgStr like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 // e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
rp.maxlifetime = maxlifetime rp.maxlifetime = maxlifetime
cfgStr = strings.TrimSpace(cfgStr)
// we think cfgStr is v2.0, using json to init the session
if strings.HasPrefix(cfgStr, "{") {
err := json.Unmarshal([]byte(cfgStr), rp)
if err != nil {
return err
}
rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr)
if err != nil {
return err
}
rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr)
if err != nil {
return err
}
} else {
rp.initOldStyle(cfgStr)
}
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
Addrs: strings.Split(rp.SavePath, ";"),
Password: rp.Password,
PoolSize: rp.Poolsize,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.MaxRetries,
})
return rp.poollist.Ping().Err()
}
// for v1.x
func (rp *Provider) initOldStyle(savePath string) {
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
if len(configs) > 0 { if len(configs) > 0 {
rp.savePath = configs[0] rp.SavePath = configs[0]
} }
if len(configs) > 1 { if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1]) poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 { if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize rp.Poolsize = MaxPoolSize
} else { } else {
rp.poolsize = poolsize rp.Poolsize = poolsize
} }
} else { } else {
rp.poolsize = MaxPoolSize rp.Poolsize = MaxPoolSize
} }
if len(configs) > 2 { if len(configs) > 2 {
rp.password = configs[2] rp.Password = configs[2]
} }
if len(configs) > 3 { if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3]) dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 { if err != nil || dbnum < 0 {
rp.dbNum = 0 rp.DbNum = 0
} else { } else {
rp.dbNum = dbnum rp.DbNum = dbnum
} }
} else { } else {
rp.dbNum = 0 rp.DbNum = 0
} }
if len(configs) > 4 { if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4]) timeout, err := strconv.Atoi(configs[4])
@ -167,19 +207,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath
if len(configs) > 6 { if len(configs) > 6 {
retries, err := strconv.Atoi(configs[6]) retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 { if err == nil && retries > 0 {
rp.maxRetries = retries rp.MaxRetries = retries
} }
} }
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
Addrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
} }
// SessionRead read redis_cluster session by sid // SessionRead read redis_cluster session by sid

View File

@ -0,0 +1,35 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package redis_cluster
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestProvider_SessionInit(t *testing.T) {
savePath := `
{ "save_path": "my save path", "idle_timeout": "3s"}
`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "my save path", cp.SavePath)
assert.Equal(t, 3*time.Second, cp.idleTimeout)
assert.Equal(t, int64(12), cp.maxlifetime)
}

View File

@ -34,6 +34,7 @@ package redis_sentinel
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -110,58 +111,99 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
// Provider redis_sentinel session provider // Provider redis_sentinel session provider
type Provider struct { type Provider struct {
maxlifetime int64 maxlifetime int64
savePath string SavePath string `json:"save_path"`
poolsize int Poolsize int `json:"poolsize"`
password string Password string `json:"password"`
dbNum int DbNum int `json:"db_num"`
idleTimeout time.Duration
idleCheckFrequency time.Duration idleTimeout time.Duration
maxRetries int IdleTimeoutStr string `json:"idle_timeout"`
poollist *redis.Client
masterName string idleCheckFrequency time.Duration
IdleCheckFrequencyStr string `json:"idle_check_frequency"`
MaxRetries int `json:"max_retries"`
poollist *redis.Client
MasterName string `json:"master_name"`
} }
// SessionInit init redis_sentinel session // SessionInit init redis_sentinel session
// savepath like redis sentinel addr,pool size,password,dbnum,masterName // cfgStr like redis sentinel addr,pool size,password,dbnum,masterName
// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster // e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
rp.maxlifetime = maxlifetime rp.maxlifetime = maxlifetime
cfgStr = strings.TrimSpace(cfgStr)
// we think cfgStr is v2.0, using json to init the session
if strings.HasPrefix(cfgStr, "{") {
err := json.Unmarshal([]byte(cfgStr), rp)
if err != nil {
return err
}
rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr)
if err != nil {
return err
}
rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr)
if err != nil {
return err
}
} else {
rp.initOldStyle(cfgStr)
}
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
SentinelAddrs: strings.Split(rp.SavePath, ";"),
Password: rp.Password,
PoolSize: rp.Poolsize,
DB: rp.DbNum,
MasterName: rp.MasterName,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.MaxRetries,
})
return rp.poollist.Ping().Err()
}
// for v1.x
func (rp *Provider) initOldStyle(savePath string) {
configs := strings.Split(savePath, ",") configs := strings.Split(savePath, ",")
if len(configs) > 0 { if len(configs) > 0 {
rp.savePath = configs[0] rp.SavePath = configs[0]
} }
if len(configs) > 1 { if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1]) poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 { if err != nil || poolsize < 0 {
rp.poolsize = DefaultPoolSize rp.Poolsize = DefaultPoolSize
} else { } else {
rp.poolsize = poolsize rp.Poolsize = poolsize
} }
} else { } else {
rp.poolsize = DefaultPoolSize rp.Poolsize = DefaultPoolSize
} }
if len(configs) > 2 { if len(configs) > 2 {
rp.password = configs[2] rp.Password = configs[2]
} }
if len(configs) > 3 { if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3]) dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 { if err != nil || dbnum < 0 {
rp.dbNum = 0 rp.DbNum = 0
} else { } else {
rp.dbNum = dbnum rp.DbNum = dbnum
} }
} else { } else {
rp.dbNum = 0 rp.DbNum = 0
} }
if len(configs) > 4 { if len(configs) > 4 {
if configs[4] != "" { if configs[4] != "" {
rp.masterName = configs[4] rp.MasterName = configs[4]
} else { } else {
rp.masterName = "mymaster" rp.MasterName = "mymaster"
} }
} else { } else {
rp.masterName = "mymaster" rp.MasterName = "mymaster"
} }
if len(configs) > 5 { if len(configs) > 5 {
timeout, err := strconv.Atoi(configs[4]) timeout, err := strconv.Atoi(configs[4])
@ -178,22 +220,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath
if len(configs) > 7 { if len(configs) > 7 {
retries, err := strconv.Atoi(configs[6]) retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 { if err == nil && retries > 0 {
rp.maxRetries = retries rp.MaxRetries = retries
} }
} }
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
SentinelAddrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
MasterName: rp.masterName,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
} }
// SessionRead read redis_sentinel session by sid // SessionRead read redis_sentinel session by sid

View File

@ -1,9 +1,13 @@
package redis_sentinel package redis_sentinel
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/server/web/session" "github.com/astaxie/beego/server/web/session"
) )
@ -23,7 +27,7 @@ func TestRedisSentinel(t *testing.T) {
t.Log(e) t.Log(e)
return return
} }
//todo test if e==nil // todo test if e==nil
go globalSessions.GC() go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil) r, _ := http.NewRequest("GET", "/", nil)
@ -88,3 +92,15 @@ func TestRedisSentinel(t *testing.T) {
sess.SessionRelease(nil, w) sess.SessionRelease(nil, w)
} }
func TestProvider_SessionInit(t *testing.T) {
savePath := `
{ "save_path": "my save path", "idle_timeout": "3s"}
`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "my save path", cp.SavePath)
assert.Equal(t, 3*time.Second, cp.idleTimeout)
assert.Equal(t, int64(12), cp.maxlifetime)
}

View File

@ -2,6 +2,7 @@ package ssdb
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"net/http" "net/http"
"strconv" "strconv"
@ -18,33 +19,48 @@ var ssdbProvider = &Provider{}
// Provider holds ssdb client and configs // Provider holds ssdb client and configs
type Provider struct { type Provider struct {
client *ssdb.Client client *ssdb.Client
host string Host string `json:"host"`
port int Port int `json:"port"`
maxLifetime int64 maxLifetime int64
} }
func (p *Provider) connectInit() error { func (p *Provider) connectInit() error {
var err error var err error
if p.host == "" || p.port == 0 { if p.Host == "" || p.Port == 0 {
return errors.New("SessionInit First") return errors.New("SessionInit First")
} }
p.client, err = ssdb.Connect(p.host, p.port) p.client, err = ssdb.Connect(p.Host, p.Port)
return err return err
} }
// SessionInit init the ssdb with the config // SessionInit init the ssdb with the config
func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error { func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, cfg string) error {
p.maxLifetime = maxLifetime p.maxLifetime = maxLifetime
address := strings.Split(savePath, ":")
p.host = address[0]
cfg = strings.TrimSpace(cfg)
var err error var err error
if p.port, err = strconv.Atoi(address[1]); err != nil { // we think this is v2.0, using json to init the session
if strings.HasPrefix(cfg, "{") {
err = json.Unmarshal([]byte(cfg), p)
} else {
err = p.initOldStyle(cfg)
}
if err != nil {
return err return err
} }
return p.connectInit() return p.connectInit()
} }
// for v1.x
func (p *Provider) initOldStyle(savePath string) error {
address := strings.Split(savePath, ":")
p.Host = address[0]
var err error
p.Port, err = strconv.Atoi(address[1])
return err
}
// SessionRead return a ssdb client session Store // SessionRead return a ssdb client session Store
func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if p.client == nil { if p.client == nil {

View File

@ -0,0 +1,41 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ssdb
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestProvider_SessionInit(t *testing.T) {
// using old style
savePath := `localhost:8080`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "localhost", cp.Host)
assert.Equal(t, 8080, cp.Port)
assert.Equal(t, int64(12), cp.maxLifetime)
savePath = `
{ "host": "localhost", "port": 8080}
`
cp = &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "localhost", cp.Host)
assert.Equal(t, 8080, cp.Port)
assert.Equal(t, int64(12), cp.maxLifetime)
}