1
0
mirror of https://github.com/astaxie/beego.git synced 2024-11-23 10:20:54 +00:00

Merge pull request #4277 from flycash/session

support using json string to init session
This commit is contained in:
Ming Deng 2020-10-22 09:53:55 +08:00 committed by GitHub
commit 02234dc503
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 464 additions and 136 deletions

View File

@ -34,6 +34,7 @@ package couchbase
import (
"context"
"encoding/json"
"net/http"
"strings"
"sync"
@ -57,9 +58,9 @@ type SessionStore struct {
// Provider couchabse provided
type Provider struct {
maxlifetime int64
savePath string
pool string
bucket string
SavePath string `json:"save_path"`
Pool string `json:"pool"`
Bucket string `json:"bucket"`
b *couchbase.Bucket
}
@ -115,17 +116,17 @@ func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
}
func (cp *Provider) getBucket() *couchbase.Bucket {
c, err := couchbase.Connect(cp.savePath)
c, err := couchbase.Connect(cp.SavePath)
if err != nil {
return nil
}
pool, err := c.GetPool(cp.pool)
pool, err := c.GetPool(cp.Pool)
if err != nil {
return nil
}
bucket, err := pool.GetBucket(cp.bucket)
bucket, err := pool.GetBucket(cp.Bucket)
if err != nil {
return nil
}
@ -135,18 +136,31 @@ func (cp *Provider) getBucket() *couchbase.Bucket {
// SessionInit init couchbase session
// savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket
func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
// For v1.x e.g. http://host:port/, Pool, Bucket
// For v2.x, you should pass json string.
// e.g. { "save_path": "http://host:port/", "pool": "mypool", "bucket": "mybucket"}
func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfg string) error {
cp.maxlifetime = maxlifetime
cfg = strings.TrimSpace(cfg)
// we think this is v2.0, using json to init the session
if strings.HasPrefix(cfg, "{") {
return json.Unmarshal([]byte(cfg), cp)
} else {
return cp.initOldStyle(cfg)
}
}
// initOldStyle keep compatible with v1.x
func (cp *Provider) initOldStyle(savePath string) error {
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
cp.savePath = configs[0]
cp.SavePath = configs[0]
}
if len(configs) > 1 {
cp.pool = configs[1]
cp.Pool = configs[1]
}
if len(configs) > 2 {
cp.bucket = configs[2]
cp.Bucket = configs[2]
}
return nil
@ -225,7 +239,7 @@ func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (
return cs, nil
}
// SessionDestroy Remove bucket in this couchbase
// SessionDestroy Remove Bucket in this couchbase
func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error {
cp.b = cp.getBucket()
defer cp.b.Close()

View File

@ -0,0 +1,43 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package couchbase
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestProvider_SessionInit(t *testing.T) {
// using old style
savePath := `http://host:port/,Pool,Bucket`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "http://host:port/", cp.SavePath)
assert.Equal(t, "Pool", cp.Pool)
assert.Equal(t, "Bucket", cp.Bucket)
assert.Equal(t, int64(12), cp.maxlifetime)
savePath = `
{ "save_path": "my save path", "pool": "mypool", "bucket": "mybucket"}
`
cp = &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "my save path", cp.SavePath)
assert.Equal(t, "mypool", cp.Pool)
assert.Equal(t, "mybucket", cp.Bucket)
assert.Equal(t, int64(12), cp.maxlifetime)
}

View File

@ -3,6 +3,7 @@ package ledis
import (
"context"
"encoding/json"
"net/http"
"strconv"
"strings"
@ -79,35 +80,51 @@ func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
// Provider ledis session provider
type Provider struct {
maxlifetime int64
savePath string
db int
SavePath string `json:"save_path"`
Db int `json:"db"`
}
// SessionInit init ledis session
// savepath like ledis server saveDataPath,pool size
// e.g. 127.0.0.1:6379,100,astaxie
func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
// v1.x e.g. 127.0.0.1:6379,100
// v2.x you should pass a json string
// e.g. { "save_path": "my save path", "db": 100}
func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
var err error
lp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) == 1 {
lp.savePath = configs[0]
} else if len(configs) == 2 {
lp.savePath = configs[0]
lp.db, err = strconv.Atoi(configs[1])
if err != nil {
return err
}
cfgStr = strings.TrimSpace(cfgStr)
// we think cfgStr is v2.0, using json to init the session
if strings.HasPrefix(cfgStr, "{") {
err = json.Unmarshal([]byte(cfgStr), lp)
} else {
err = lp.initOldStyle(cfgStr)
}
if err != nil {
return err
}
cfg := new(config.Config)
cfg.DataDir = lp.savePath
cfg.DataDir = lp.SavePath
var ledisInstance *ledis.Ledis
ledisInstance, err = ledis.Open(cfg)
if err != nil {
return err
}
c, err = ledisInstance.Select(lp.db)
c, err = ledisInstance.Select(lp.Db)
return err
}
func (lp *Provider) initOldStyle(cfgStr string) error {
var err error
configs := strings.Split(cfgStr, ",")
if len(configs) == 1 {
lp.SavePath = configs[0]
} else if len(configs) == 2 {
lp.SavePath = configs[0]
lp.Db, err = strconv.Atoi(configs[1])
}
return err
}

View File

@ -0,0 +1,41 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ledis
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestProvider_SessionInit(t *testing.T) {
// using old style
savePath := `http://host:port/,100`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "http://host:port/", cp.SavePath)
assert.Equal(t, 100, cp.Db)
assert.Equal(t, int64(12), cp.maxlifetime)
savePath = `
{ "save_path": "my save path", "db": 100}
`
cp = &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "my save path", cp.SavePath)
assert.Equal(t, 100, cp.Db)
assert.Equal(t, int64(12), cp.maxlifetime)
}

View File

@ -34,6 +34,7 @@ package redis
import (
"context"
"encoding/json"
"net/http"
"strconv"
"strings"
@ -110,48 +111,89 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
// Provider redis session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *redis.Client
maxlifetime int64
SavePath string `json:"save_path"`
Poolsize int `json:"poolsize"`
Password string `json:"password"`
DbNum int `json:"db_num"`
idleTimeout time.Duration
IdleTimeoutStr string `json:"idle_timeout"`
idleCheckFrequency time.Duration
IdleCheckFrequencyStr string `json:"idle_check_frequency"`
MaxRetries int `json:"max_retries"`
poollist *redis.Client
}
// SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
// e.g. 127.0.0.1:6379,100,astaxie,0,30
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
// v1.x e.g. 127.0.0.1:6379,100,astaxie,0,30
// v2.0 you should pass json string
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
rp.maxlifetime = maxlifetime
cfgStr = strings.TrimSpace(cfgStr)
// we think cfgStr is v2.0, using json to init the session
if strings.HasPrefix(cfgStr, "{") {
err := json.Unmarshal([]byte(cfgStr), rp)
if err != nil {
return err
}
rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr)
if err != nil {
return err
}
rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr)
if err != nil {
return err
}
} else {
rp.initOldStyle(cfgStr)
}
rp.poollist = redis.NewClient(&redis.Options{
Addr: rp.SavePath,
Password: rp.Password,
PoolSize: rp.Poolsize,
DB: rp.DbNum,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.MaxRetries,
})
return rp.poollist.Ping().Err()
}
func (rp *Provider) initOldStyle(savePath string) {
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
rp.SavePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize
rp.Poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
rp.Poolsize = poolsize
}
} else {
rp.poolsize = MaxPoolSize
rp.Poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
rp.Password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
rp.DbNum = 0
} else {
rp.dbNum = dbnum
rp.DbNum = dbnum
}
} else {
rp.dbNum = 0
rp.DbNum = 0
}
if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4])
@ -168,21 +210,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath
if len(configs) > 6 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
rp.MaxRetries = retries
}
}
rp.poollist = redis.NewClient(&redis.Options{
Addr: rp.savePath,
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis session by sid

View File

@ -1,11 +1,15 @@
package redis
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/server/web/session"
)
@ -94,3 +98,15 @@ func TestRedis(t *testing.T) {
sess.SessionRelease(nil, w)
}
func TestProvider_SessionInit(t *testing.T) {
savePath := `
{ "save_path": "my save path", "idle_timeout": "3s"}
`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "my save path", cp.SavePath)
assert.Equal(t, 3*time.Second, cp.idleTimeout)
assert.Equal(t, int64(12), cp.maxlifetime)
}

View File

@ -34,14 +34,16 @@ package redis_cluster
import (
"context"
"encoding/json"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/server/web/session"
rediss "github.com/go-redis/redis/v7"
"github.com/astaxie/beego/server/web/session"
)
var redispder = &Provider{}
@ -109,48 +111,86 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
// Provider redis_cluster session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *rediss.ClusterClient
maxlifetime int64
SavePath string `json:"save_path"`
Poolsize int `json:"poolsize"`
Password string `json:"password"`
DbNum int `json:"db_num"`
idleTimeout time.Duration
IdleTimeoutStr string `json:"idle_timeout"`
idleCheckFrequency time.Duration
IdleCheckFrequencyStr string `json:"idle_check_frequency"`
MaxRetries int `json:"max_retries"`
poollist *rediss.ClusterClient
}
// SessionInit init redis_cluster session
// savepath like redis server addr,pool size,password,dbnum
// cfgStr like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
rp.maxlifetime = maxlifetime
cfgStr = strings.TrimSpace(cfgStr)
// we think cfgStr is v2.0, using json to init the session
if strings.HasPrefix(cfgStr, "{") {
err := json.Unmarshal([]byte(cfgStr), rp)
if err != nil {
return err
}
rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr)
if err != nil {
return err
}
rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr)
if err != nil {
return err
}
} else {
rp.initOldStyle(cfgStr)
}
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
Addrs: strings.Split(rp.SavePath, ";"),
Password: rp.Password,
PoolSize: rp.Poolsize,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.MaxRetries,
})
return rp.poollist.Ping().Err()
}
// for v1.x
func (rp *Provider) initOldStyle(savePath string) {
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
rp.SavePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize
rp.Poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
rp.Poolsize = poolsize
}
} else {
rp.poolsize = MaxPoolSize
rp.Poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
rp.Password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
rp.DbNum = 0
} else {
rp.dbNum = dbnum
rp.DbNum = dbnum
}
} else {
rp.dbNum = 0
rp.DbNum = 0
}
if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4])
@ -167,19 +207,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath
if len(configs) > 6 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
rp.MaxRetries = retries
}
}
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
Addrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_cluster session by sid

View File

@ -0,0 +1,35 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package redis_cluster
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestProvider_SessionInit(t *testing.T) {
savePath := `
{ "save_path": "my save path", "idle_timeout": "3s"}
`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "my save path", cp.SavePath)
assert.Equal(t, 3*time.Second, cp.idleTimeout)
assert.Equal(t, int64(12), cp.maxlifetime)
}

View File

@ -34,6 +34,7 @@ package redis_sentinel
import (
"context"
"encoding/json"
"net/http"
"strconv"
"strings"
@ -110,58 +111,99 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
// Provider redis_sentinel session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *redis.Client
masterName string
maxlifetime int64
SavePath string `json:"save_path"`
Poolsize int `json:"poolsize"`
Password string `json:"password"`
DbNum int `json:"db_num"`
idleTimeout time.Duration
IdleTimeoutStr string `json:"idle_timeout"`
idleCheckFrequency time.Duration
IdleCheckFrequencyStr string `json:"idle_check_frequency"`
MaxRetries int `json:"max_retries"`
poollist *redis.Client
MasterName string `json:"master_name"`
}
// SessionInit init redis_sentinel session
// savepath like redis sentinel addr,pool size,password,dbnum,masterName
// cfgStr like redis sentinel addr,pool size,password,dbnum,masterName
// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
rp.maxlifetime = maxlifetime
cfgStr = strings.TrimSpace(cfgStr)
// we think cfgStr is v2.0, using json to init the session
if strings.HasPrefix(cfgStr, "{") {
err := json.Unmarshal([]byte(cfgStr), rp)
if err != nil {
return err
}
rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr)
if err != nil {
return err
}
rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr)
if err != nil {
return err
}
} else {
rp.initOldStyle(cfgStr)
}
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
SentinelAddrs: strings.Split(rp.SavePath, ";"),
Password: rp.Password,
PoolSize: rp.Poolsize,
DB: rp.DbNum,
MasterName: rp.MasterName,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.MaxRetries,
})
return rp.poollist.Ping().Err()
}
// for v1.x
func (rp *Provider) initOldStyle(savePath string) {
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
rp.SavePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = DefaultPoolSize
rp.Poolsize = DefaultPoolSize
} else {
rp.poolsize = poolsize
rp.Poolsize = poolsize
}
} else {
rp.poolsize = DefaultPoolSize
rp.Poolsize = DefaultPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
rp.Password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
rp.DbNum = 0
} else {
rp.dbNum = dbnum
rp.DbNum = dbnum
}
} else {
rp.dbNum = 0
rp.DbNum = 0
}
if len(configs) > 4 {
if configs[4] != "" {
rp.masterName = configs[4]
rp.MasterName = configs[4]
} else {
rp.masterName = "mymaster"
rp.MasterName = "mymaster"
}
} else {
rp.masterName = "mymaster"
rp.MasterName = "mymaster"
}
if len(configs) > 5 {
timeout, err := strconv.Atoi(configs[4])
@ -178,22 +220,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath
if len(configs) > 7 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
rp.MaxRetries = retries
}
}
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
SentinelAddrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
MasterName: rp.masterName,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_sentinel session by sid

View File

@ -1,9 +1,13 @@
package redis_sentinel
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/astaxie/beego/server/web/session"
)
@ -23,7 +27,7 @@ func TestRedisSentinel(t *testing.T) {
t.Log(e)
return
}
//todo test if e==nil
// todo test if e==nil
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
@ -88,3 +92,15 @@ func TestRedisSentinel(t *testing.T) {
sess.SessionRelease(nil, w)
}
func TestProvider_SessionInit(t *testing.T) {
savePath := `
{ "save_path": "my save path", "idle_timeout": "3s"}
`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "my save path", cp.SavePath)
assert.Equal(t, 3*time.Second, cp.idleTimeout)
assert.Equal(t, int64(12), cp.maxlifetime)
}

View File

@ -2,6 +2,7 @@ package ssdb
import (
"context"
"encoding/json"
"errors"
"net/http"
"strconv"
@ -18,33 +19,48 @@ var ssdbProvider = &Provider{}
// Provider holds ssdb client and configs
type Provider struct {
client *ssdb.Client
host string
port int
Host string `json:"host"`
Port int `json:"port"`
maxLifetime int64
}
func (p *Provider) connectInit() error {
var err error
if p.host == "" || p.port == 0 {
if p.Host == "" || p.Port == 0 {
return errors.New("SessionInit First")
}
p.client, err = ssdb.Connect(p.host, p.port)
p.client, err = ssdb.Connect(p.Host, p.Port)
return err
}
// SessionInit init the ssdb with the config
func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error {
func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, cfg string) error {
p.maxLifetime = maxLifetime
address := strings.Split(savePath, ":")
p.host = address[0]
cfg = strings.TrimSpace(cfg)
var err error
if p.port, err = strconv.Atoi(address[1]); err != nil {
// we think this is v2.0, using json to init the session
if strings.HasPrefix(cfg, "{") {
err = json.Unmarshal([]byte(cfg), p)
} else {
err = p.initOldStyle(cfg)
}
if err != nil {
return err
}
return p.connectInit()
}
// for v1.x
func (p *Provider) initOldStyle(savePath string) error {
address := strings.Split(savePath, ":")
p.Host = address[0]
var err error
p.Port, err = strconv.Atoi(address[1])
return err
}
// SessionRead return a ssdb client session Store
func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if p.client == nil {

View File

@ -0,0 +1,41 @@
// Copyright 2020
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ssdb
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestProvider_SessionInit(t *testing.T) {
// using old style
savePath := `localhost:8080`
cp := &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "localhost", cp.Host)
assert.Equal(t, 8080, cp.Port)
assert.Equal(t, int64(12), cp.maxLifetime)
savePath = `
{ "host": "localhost", "port": 8080}
`
cp = &Provider{}
cp.SessionInit(context.Background(), 12, savePath)
assert.Equal(t, "localhost", cp.Host)
assert.Equal(t, 8080, cp.Port)
assert.Equal(t, int64(12), cp.maxLifetime)
}