mirror of
https://github.com/astaxie/beego.git
synced 2024-12-23 16:20:49 +00:00
Merge pull request #4277 from flycash/session
support using json string to init session
This commit is contained in:
commit
02234dc503
@ -34,6 +34,7 @@ package couchbase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -57,9 +58,9 @@ type SessionStore struct {
|
||||
// Provider couchabse provided
|
||||
type Provider struct {
|
||||
maxlifetime int64
|
||||
savePath string
|
||||
pool string
|
||||
bucket string
|
||||
SavePath string `json:"save_path"`
|
||||
Pool string `json:"pool"`
|
||||
Bucket string `json:"bucket"`
|
||||
b *couchbase.Bucket
|
||||
}
|
||||
|
||||
@ -115,17 +116,17 @@ func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
|
||||
}
|
||||
|
||||
func (cp *Provider) getBucket() *couchbase.Bucket {
|
||||
c, err := couchbase.Connect(cp.savePath)
|
||||
c, err := couchbase.Connect(cp.SavePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pool, err := c.GetPool(cp.pool)
|
||||
pool, err := c.GetPool(cp.Pool)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bucket, err := pool.GetBucket(cp.bucket)
|
||||
bucket, err := pool.GetBucket(cp.Bucket)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@ -135,18 +136,31 @@ func (cp *Provider) getBucket() *couchbase.Bucket {
|
||||
|
||||
// SessionInit init couchbase session
|
||||
// savepath like couchbase server REST/JSON URL
|
||||
// e.g. http://host:port/, Pool, Bucket
|
||||
func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
|
||||
// For v1.x e.g. http://host:port/, Pool, Bucket
|
||||
// For v2.x, you should pass json string.
|
||||
// e.g. { "save_path": "http://host:port/", "pool": "mypool", "bucket": "mybucket"}
|
||||
func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfg string) error {
|
||||
cp.maxlifetime = maxlifetime
|
||||
cfg = strings.TrimSpace(cfg)
|
||||
// we think this is v2.0, using json to init the session
|
||||
if strings.HasPrefix(cfg, "{") {
|
||||
return json.Unmarshal([]byte(cfg), cp)
|
||||
} else {
|
||||
return cp.initOldStyle(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// initOldStyle keep compatible with v1.x
|
||||
func (cp *Provider) initOldStyle(savePath string) error {
|
||||
configs := strings.Split(savePath, ",")
|
||||
if len(configs) > 0 {
|
||||
cp.savePath = configs[0]
|
||||
cp.SavePath = configs[0]
|
||||
}
|
||||
if len(configs) > 1 {
|
||||
cp.pool = configs[1]
|
||||
cp.Pool = configs[1]
|
||||
}
|
||||
if len(configs) > 2 {
|
||||
cp.bucket = configs[2]
|
||||
cp.Bucket = configs[2]
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -225,7 +239,7 @@ func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
// SessionDestroy Remove bucket in this couchbase
|
||||
// SessionDestroy Remove Bucket in this couchbase
|
||||
func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error {
|
||||
cp.b = cp.getBucket()
|
||||
defer cp.b.Close()
|
||||
|
43
server/web/session/couchbase/sess_couchbase_test.go
Normal file
43
server/web/session/couchbase/sess_couchbase_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
// Copyright 2020
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package couchbase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProvider_SessionInit(t *testing.T) {
|
||||
// using old style
|
||||
savePath := `http://host:port/,Pool,Bucket`
|
||||
cp := &Provider{}
|
||||
cp.SessionInit(context.Background(), 12, savePath)
|
||||
assert.Equal(t, "http://host:port/", cp.SavePath)
|
||||
assert.Equal(t, "Pool", cp.Pool)
|
||||
assert.Equal(t, "Bucket", cp.Bucket)
|
||||
assert.Equal(t, int64(12), cp.maxlifetime)
|
||||
|
||||
savePath = `
|
||||
{ "save_path": "my save path", "pool": "mypool", "bucket": "mybucket"}
|
||||
`
|
||||
cp = &Provider{}
|
||||
cp.SessionInit(context.Background(), 12, savePath)
|
||||
assert.Equal(t, "my save path", cp.SavePath)
|
||||
assert.Equal(t, "mypool", cp.Pool)
|
||||
assert.Equal(t, "mybucket", cp.Bucket)
|
||||
assert.Equal(t, int64(12), cp.maxlifetime)
|
||||
}
|
@ -3,6 +3,7 @@ package ledis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -79,35 +80,51 @@ func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
|
||||
// Provider ledis session provider
|
||||
type Provider struct {
|
||||
maxlifetime int64
|
||||
savePath string
|
||||
db int
|
||||
SavePath string `json:"save_path"`
|
||||
Db int `json:"db"`
|
||||
}
|
||||
|
||||
// SessionInit init ledis session
|
||||
// savepath like ledis server saveDataPath,pool size
|
||||
// e.g. 127.0.0.1:6379,100,astaxie
|
||||
func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
|
||||
// v1.x e.g. 127.0.0.1:6379,100
|
||||
// v2.x you should pass a json string
|
||||
// e.g. { "save_path": "my save path", "db": 100}
|
||||
func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
|
||||
var err error
|
||||
lp.maxlifetime = maxlifetime
|
||||
configs := strings.Split(savePath, ",")
|
||||
if len(configs) == 1 {
|
||||
lp.savePath = configs[0]
|
||||
} else if len(configs) == 2 {
|
||||
lp.savePath = configs[0]
|
||||
lp.db, err = strconv.Atoi(configs[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfgStr = strings.TrimSpace(cfgStr)
|
||||
// we think cfgStr is v2.0, using json to init the session
|
||||
if strings.HasPrefix(cfgStr, "{") {
|
||||
err = json.Unmarshal([]byte(cfgStr), lp)
|
||||
} else {
|
||||
err = lp.initOldStyle(cfgStr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := new(config.Config)
|
||||
cfg.DataDir = lp.savePath
|
||||
cfg.DataDir = lp.SavePath
|
||||
|
||||
var ledisInstance *ledis.Ledis
|
||||
ledisInstance, err = ledis.Open(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c, err = ledisInstance.Select(lp.db)
|
||||
c, err = ledisInstance.Select(lp.Db)
|
||||
return err
|
||||
}
|
||||
|
||||
func (lp *Provider) initOldStyle(cfgStr string) error {
|
||||
var err error
|
||||
configs := strings.Split(cfgStr, ",")
|
||||
if len(configs) == 1 {
|
||||
lp.SavePath = configs[0]
|
||||
} else if len(configs) == 2 {
|
||||
lp.SavePath = configs[0]
|
||||
lp.Db, err = strconv.Atoi(configs[1])
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
41
server/web/session/ledis/ledis_session_test.go
Normal file
41
server/web/session/ledis/ledis_session_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright 2020
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ledis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProvider_SessionInit(t *testing.T) {
|
||||
// using old style
|
||||
savePath := `http://host:port/,100`
|
||||
cp := &Provider{}
|
||||
cp.SessionInit(context.Background(), 12, savePath)
|
||||
assert.Equal(t, "http://host:port/", cp.SavePath)
|
||||
assert.Equal(t, 100, cp.Db)
|
||||
assert.Equal(t, int64(12), cp.maxlifetime)
|
||||
|
||||
savePath = `
|
||||
{ "save_path": "my save path", "db": 100}
|
||||
`
|
||||
cp = &Provider{}
|
||||
cp.SessionInit(context.Background(), 12, savePath)
|
||||
assert.Equal(t, "my save path", cp.SavePath)
|
||||
assert.Equal(t, 100, cp.Db)
|
||||
assert.Equal(t, int64(12), cp.maxlifetime)
|
||||
}
|
@ -34,6 +34,7 @@ package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -110,48 +111,89 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
|
||||
|
||||
// Provider redis session provider
|
||||
type Provider struct {
|
||||
maxlifetime int64
|
||||
savePath string
|
||||
poolsize int
|
||||
password string
|
||||
dbNum int
|
||||
idleTimeout time.Duration
|
||||
idleCheckFrequency time.Duration
|
||||
maxRetries int
|
||||
poollist *redis.Client
|
||||
maxlifetime int64
|
||||
SavePath string `json:"save_path"`
|
||||
Poolsize int `json:"poolsize"`
|
||||
Password string `json:"password"`
|
||||
DbNum int `json:"db_num"`
|
||||
|
||||
idleTimeout time.Duration
|
||||
IdleTimeoutStr string `json:"idle_timeout"`
|
||||
|
||||
idleCheckFrequency time.Duration
|
||||
IdleCheckFrequencyStr string `json:"idle_check_frequency"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
poollist *redis.Client
|
||||
}
|
||||
|
||||
// SessionInit init redis session
|
||||
// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
|
||||
// e.g. 127.0.0.1:6379,100,astaxie,0,30
|
||||
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
|
||||
// v1.x e.g. 127.0.0.1:6379,100,astaxie,0,30
|
||||
// v2.0 you should pass json string
|
||||
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
|
||||
rp.maxlifetime = maxlifetime
|
||||
|
||||
cfgStr = strings.TrimSpace(cfgStr)
|
||||
// we think cfgStr is v2.0, using json to init the session
|
||||
if strings.HasPrefix(cfgStr, "{") {
|
||||
err := json.Unmarshal([]byte(cfgStr), rp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
rp.initOldStyle(cfgStr)
|
||||
}
|
||||
|
||||
rp.poollist = redis.NewClient(&redis.Options{
|
||||
Addr: rp.SavePath,
|
||||
Password: rp.Password,
|
||||
PoolSize: rp.Poolsize,
|
||||
DB: rp.DbNum,
|
||||
IdleTimeout: rp.idleTimeout,
|
||||
IdleCheckFrequency: rp.idleCheckFrequency,
|
||||
MaxRetries: rp.MaxRetries,
|
||||
})
|
||||
|
||||
return rp.poollist.Ping().Err()
|
||||
}
|
||||
|
||||
func (rp *Provider) initOldStyle(savePath string) {
|
||||
configs := strings.Split(savePath, ",")
|
||||
if len(configs) > 0 {
|
||||
rp.savePath = configs[0]
|
||||
rp.SavePath = configs[0]
|
||||
}
|
||||
if len(configs) > 1 {
|
||||
poolsize, err := strconv.Atoi(configs[1])
|
||||
if err != nil || poolsize < 0 {
|
||||
rp.poolsize = MaxPoolSize
|
||||
rp.Poolsize = MaxPoolSize
|
||||
} else {
|
||||
rp.poolsize = poolsize
|
||||
rp.Poolsize = poolsize
|
||||
}
|
||||
} else {
|
||||
rp.poolsize = MaxPoolSize
|
||||
rp.Poolsize = MaxPoolSize
|
||||
}
|
||||
if len(configs) > 2 {
|
||||
rp.password = configs[2]
|
||||
rp.Password = configs[2]
|
||||
}
|
||||
if len(configs) > 3 {
|
||||
dbnum, err := strconv.Atoi(configs[3])
|
||||
if err != nil || dbnum < 0 {
|
||||
rp.dbNum = 0
|
||||
rp.DbNum = 0
|
||||
} else {
|
||||
rp.dbNum = dbnum
|
||||
rp.DbNum = dbnum
|
||||
}
|
||||
} else {
|
||||
rp.dbNum = 0
|
||||
rp.DbNum = 0
|
||||
}
|
||||
if len(configs) > 4 {
|
||||
timeout, err := strconv.Atoi(configs[4])
|
||||
@ -168,21 +210,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath
|
||||
if len(configs) > 6 {
|
||||
retries, err := strconv.Atoi(configs[6])
|
||||
if err == nil && retries > 0 {
|
||||
rp.maxRetries = retries
|
||||
rp.MaxRetries = retries
|
||||
}
|
||||
}
|
||||
|
||||
rp.poollist = redis.NewClient(&redis.Options{
|
||||
Addr: rp.savePath,
|
||||
Password: rp.password,
|
||||
PoolSize: rp.poolsize,
|
||||
DB: rp.dbNum,
|
||||
IdleTimeout: rp.idleTimeout,
|
||||
IdleCheckFrequency: rp.idleCheckFrequency,
|
||||
MaxRetries: rp.maxRetries,
|
||||
})
|
||||
|
||||
return rp.poollist.Ping().Err()
|
||||
}
|
||||
|
||||
// SessionRead read redis session by sid
|
||||
|
@ -1,11 +1,15 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/astaxie/beego/server/web/session"
|
||||
)
|
||||
@ -94,3 +98,15 @@ func TestRedis(t *testing.T) {
|
||||
|
||||
sess.SessionRelease(nil, w)
|
||||
}
|
||||
|
||||
func TestProvider_SessionInit(t *testing.T) {
|
||||
|
||||
savePath := `
|
||||
{ "save_path": "my save path", "idle_timeout": "3s"}
|
||||
`
|
||||
cp := &Provider{}
|
||||
cp.SessionInit(context.Background(), 12, savePath)
|
||||
assert.Equal(t, "my save path", cp.SavePath)
|
||||
assert.Equal(t, 3*time.Second, cp.idleTimeout)
|
||||
assert.Equal(t, int64(12), cp.maxlifetime)
|
||||
}
|
||||
|
@ -34,14 +34,16 @@ package redis_cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/astaxie/beego/server/web/session"
|
||||
rediss "github.com/go-redis/redis/v7"
|
||||
|
||||
"github.com/astaxie/beego/server/web/session"
|
||||
)
|
||||
|
||||
var redispder = &Provider{}
|
||||
@ -109,48 +111,86 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
|
||||
|
||||
// Provider redis_cluster session provider
|
||||
type Provider struct {
|
||||
maxlifetime int64
|
||||
savePath string
|
||||
poolsize int
|
||||
password string
|
||||
dbNum int
|
||||
idleTimeout time.Duration
|
||||
idleCheckFrequency time.Duration
|
||||
maxRetries int
|
||||
poollist *rediss.ClusterClient
|
||||
maxlifetime int64
|
||||
SavePath string `json:"save_path"`
|
||||
Poolsize int `json:"poolsize"`
|
||||
Password string `json:"password"`
|
||||
DbNum int `json:"db_num"`
|
||||
|
||||
idleTimeout time.Duration
|
||||
IdleTimeoutStr string `json:"idle_timeout"`
|
||||
|
||||
idleCheckFrequency time.Duration
|
||||
IdleCheckFrequencyStr string `json:"idle_check_frequency"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
poollist *rediss.ClusterClient
|
||||
}
|
||||
|
||||
// SessionInit init redis_cluster session
|
||||
// savepath like redis server addr,pool size,password,dbnum
|
||||
// cfgStr like redis server addr,pool size,password,dbnum
|
||||
// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0
|
||||
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
|
||||
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
|
||||
rp.maxlifetime = maxlifetime
|
||||
cfgStr = strings.TrimSpace(cfgStr)
|
||||
// we think cfgStr is v2.0, using json to init the session
|
||||
if strings.HasPrefix(cfgStr, "{") {
|
||||
err := json.Unmarshal([]byte(cfgStr), rp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
rp.initOldStyle(cfgStr)
|
||||
}
|
||||
|
||||
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
|
||||
Addrs: strings.Split(rp.SavePath, ";"),
|
||||
Password: rp.Password,
|
||||
PoolSize: rp.Poolsize,
|
||||
IdleTimeout: rp.idleTimeout,
|
||||
IdleCheckFrequency: rp.idleCheckFrequency,
|
||||
MaxRetries: rp.MaxRetries,
|
||||
})
|
||||
return rp.poollist.Ping().Err()
|
||||
}
|
||||
|
||||
// for v1.x
|
||||
func (rp *Provider) initOldStyle(savePath string) {
|
||||
configs := strings.Split(savePath, ",")
|
||||
if len(configs) > 0 {
|
||||
rp.savePath = configs[0]
|
||||
rp.SavePath = configs[0]
|
||||
}
|
||||
if len(configs) > 1 {
|
||||
poolsize, err := strconv.Atoi(configs[1])
|
||||
if err != nil || poolsize < 0 {
|
||||
rp.poolsize = MaxPoolSize
|
||||
rp.Poolsize = MaxPoolSize
|
||||
} else {
|
||||
rp.poolsize = poolsize
|
||||
rp.Poolsize = poolsize
|
||||
}
|
||||
} else {
|
||||
rp.poolsize = MaxPoolSize
|
||||
rp.Poolsize = MaxPoolSize
|
||||
}
|
||||
if len(configs) > 2 {
|
||||
rp.password = configs[2]
|
||||
rp.Password = configs[2]
|
||||
}
|
||||
if len(configs) > 3 {
|
||||
dbnum, err := strconv.Atoi(configs[3])
|
||||
if err != nil || dbnum < 0 {
|
||||
rp.dbNum = 0
|
||||
rp.DbNum = 0
|
||||
} else {
|
||||
rp.dbNum = dbnum
|
||||
rp.DbNum = dbnum
|
||||
}
|
||||
} else {
|
||||
rp.dbNum = 0
|
||||
rp.DbNum = 0
|
||||
}
|
||||
if len(configs) > 4 {
|
||||
timeout, err := strconv.Atoi(configs[4])
|
||||
@ -167,19 +207,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath
|
||||
if len(configs) > 6 {
|
||||
retries, err := strconv.Atoi(configs[6])
|
||||
if err == nil && retries > 0 {
|
||||
rp.maxRetries = retries
|
||||
rp.MaxRetries = retries
|
||||
}
|
||||
}
|
||||
|
||||
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
|
||||
Addrs: strings.Split(rp.savePath, ";"),
|
||||
Password: rp.password,
|
||||
PoolSize: rp.poolsize,
|
||||
IdleTimeout: rp.idleTimeout,
|
||||
IdleCheckFrequency: rp.idleCheckFrequency,
|
||||
MaxRetries: rp.maxRetries,
|
||||
})
|
||||
return rp.poollist.Ping().Err()
|
||||
}
|
||||
|
||||
// SessionRead read redis_cluster session by sid
|
||||
|
35
server/web/session/redis_cluster/redis_cluster_test.go
Normal file
35
server/web/session/redis_cluster/redis_cluster_test.go
Normal file
@ -0,0 +1,35 @@
|
||||
// Copyright 2020
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package redis_cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProvider_SessionInit(t *testing.T) {
|
||||
|
||||
savePath := `
|
||||
{ "save_path": "my save path", "idle_timeout": "3s"}
|
||||
`
|
||||
cp := &Provider{}
|
||||
cp.SessionInit(context.Background(), 12, savePath)
|
||||
assert.Equal(t, "my save path", cp.SavePath)
|
||||
assert.Equal(t, 3*time.Second, cp.idleTimeout)
|
||||
assert.Equal(t, int64(12), cp.maxlifetime)
|
||||
}
|
@ -34,6 +34,7 @@ package redis_sentinel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -110,58 +111,99 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
|
||||
|
||||
// Provider redis_sentinel session provider
|
||||
type Provider struct {
|
||||
maxlifetime int64
|
||||
savePath string
|
||||
poolsize int
|
||||
password string
|
||||
dbNum int
|
||||
idleTimeout time.Duration
|
||||
idleCheckFrequency time.Duration
|
||||
maxRetries int
|
||||
poollist *redis.Client
|
||||
masterName string
|
||||
maxlifetime int64
|
||||
SavePath string `json:"save_path"`
|
||||
Poolsize int `json:"poolsize"`
|
||||
Password string `json:"password"`
|
||||
DbNum int `json:"db_num"`
|
||||
|
||||
idleTimeout time.Duration
|
||||
IdleTimeoutStr string `json:"idle_timeout"`
|
||||
|
||||
idleCheckFrequency time.Duration
|
||||
IdleCheckFrequencyStr string `json:"idle_check_frequency"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
poollist *redis.Client
|
||||
MasterName string `json:"master_name"`
|
||||
}
|
||||
|
||||
// SessionInit init redis_sentinel session
|
||||
// savepath like redis sentinel addr,pool size,password,dbnum,masterName
|
||||
// cfgStr like redis sentinel addr,pool size,password,dbnum,masterName
|
||||
// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
|
||||
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
|
||||
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error {
|
||||
rp.maxlifetime = maxlifetime
|
||||
cfgStr = strings.TrimSpace(cfgStr)
|
||||
// we think cfgStr is v2.0, using json to init the session
|
||||
if strings.HasPrefix(cfgStr, "{") {
|
||||
err := json.Unmarshal([]byte(cfgStr), rp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
rp.initOldStyle(cfgStr)
|
||||
}
|
||||
|
||||
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
|
||||
SentinelAddrs: strings.Split(rp.SavePath, ";"),
|
||||
Password: rp.Password,
|
||||
PoolSize: rp.Poolsize,
|
||||
DB: rp.DbNum,
|
||||
MasterName: rp.MasterName,
|
||||
IdleTimeout: rp.idleTimeout,
|
||||
IdleCheckFrequency: rp.idleCheckFrequency,
|
||||
MaxRetries: rp.MaxRetries,
|
||||
})
|
||||
|
||||
return rp.poollist.Ping().Err()
|
||||
}
|
||||
|
||||
// for v1.x
|
||||
func (rp *Provider) initOldStyle(savePath string) {
|
||||
configs := strings.Split(savePath, ",")
|
||||
if len(configs) > 0 {
|
||||
rp.savePath = configs[0]
|
||||
rp.SavePath = configs[0]
|
||||
}
|
||||
if len(configs) > 1 {
|
||||
poolsize, err := strconv.Atoi(configs[1])
|
||||
if err != nil || poolsize < 0 {
|
||||
rp.poolsize = DefaultPoolSize
|
||||
rp.Poolsize = DefaultPoolSize
|
||||
} else {
|
||||
rp.poolsize = poolsize
|
||||
rp.Poolsize = poolsize
|
||||
}
|
||||
} else {
|
||||
rp.poolsize = DefaultPoolSize
|
||||
rp.Poolsize = DefaultPoolSize
|
||||
}
|
||||
if len(configs) > 2 {
|
||||
rp.password = configs[2]
|
||||
rp.Password = configs[2]
|
||||
}
|
||||
if len(configs) > 3 {
|
||||
dbnum, err := strconv.Atoi(configs[3])
|
||||
if err != nil || dbnum < 0 {
|
||||
rp.dbNum = 0
|
||||
rp.DbNum = 0
|
||||
} else {
|
||||
rp.dbNum = dbnum
|
||||
rp.DbNum = dbnum
|
||||
}
|
||||
} else {
|
||||
rp.dbNum = 0
|
||||
rp.DbNum = 0
|
||||
}
|
||||
if len(configs) > 4 {
|
||||
if configs[4] != "" {
|
||||
rp.masterName = configs[4]
|
||||
rp.MasterName = configs[4]
|
||||
} else {
|
||||
rp.masterName = "mymaster"
|
||||
rp.MasterName = "mymaster"
|
||||
}
|
||||
} else {
|
||||
rp.masterName = "mymaster"
|
||||
rp.MasterName = "mymaster"
|
||||
}
|
||||
if len(configs) > 5 {
|
||||
timeout, err := strconv.Atoi(configs[4])
|
||||
@ -178,22 +220,9 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath
|
||||
if len(configs) > 7 {
|
||||
retries, err := strconv.Atoi(configs[6])
|
||||
if err == nil && retries > 0 {
|
||||
rp.maxRetries = retries
|
||||
rp.MaxRetries = retries
|
||||
}
|
||||
}
|
||||
|
||||
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
|
||||
SentinelAddrs: strings.Split(rp.savePath, ";"),
|
||||
Password: rp.password,
|
||||
PoolSize: rp.poolsize,
|
||||
DB: rp.dbNum,
|
||||
MasterName: rp.masterName,
|
||||
IdleTimeout: rp.idleTimeout,
|
||||
IdleCheckFrequency: rp.idleCheckFrequency,
|
||||
MaxRetries: rp.maxRetries,
|
||||
})
|
||||
|
||||
return rp.poollist.Ping().Err()
|
||||
}
|
||||
|
||||
// SessionRead read redis_sentinel session by sid
|
||||
|
@ -1,9 +1,13 @@
|
||||
package redis_sentinel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/astaxie/beego/server/web/session"
|
||||
)
|
||||
@ -23,7 +27,7 @@ func TestRedisSentinel(t *testing.T) {
|
||||
t.Log(e)
|
||||
return
|
||||
}
|
||||
//todo test if e==nil
|
||||
// todo test if e==nil
|
||||
go globalSessions.GC()
|
||||
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
@ -88,3 +92,15 @@ func TestRedisSentinel(t *testing.T) {
|
||||
sess.SessionRelease(nil, w)
|
||||
|
||||
}
|
||||
|
||||
func TestProvider_SessionInit(t *testing.T) {
|
||||
|
||||
savePath := `
|
||||
{ "save_path": "my save path", "idle_timeout": "3s"}
|
||||
`
|
||||
cp := &Provider{}
|
||||
cp.SessionInit(context.Background(), 12, savePath)
|
||||
assert.Equal(t, "my save path", cp.SavePath)
|
||||
assert.Equal(t, 3*time.Second, cp.idleTimeout)
|
||||
assert.Equal(t, int64(12), cp.maxlifetime)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package ssdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@ -18,33 +19,48 @@ var ssdbProvider = &Provider{}
|
||||
// Provider holds ssdb client and configs
|
||||
type Provider struct {
|
||||
client *ssdb.Client
|
||||
host string
|
||||
port int
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
maxLifetime int64
|
||||
}
|
||||
|
||||
func (p *Provider) connectInit() error {
|
||||
var err error
|
||||
if p.host == "" || p.port == 0 {
|
||||
if p.Host == "" || p.Port == 0 {
|
||||
return errors.New("SessionInit First")
|
||||
}
|
||||
p.client, err = ssdb.Connect(p.host, p.port)
|
||||
p.client, err = ssdb.Connect(p.Host, p.Port)
|
||||
return err
|
||||
}
|
||||
|
||||
// SessionInit init the ssdb with the config
|
||||
func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error {
|
||||
func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, cfg string) error {
|
||||
p.maxLifetime = maxLifetime
|
||||
address := strings.Split(savePath, ":")
|
||||
p.host = address[0]
|
||||
|
||||
cfg = strings.TrimSpace(cfg)
|
||||
var err error
|
||||
if p.port, err = strconv.Atoi(address[1]); err != nil {
|
||||
// we think this is v2.0, using json to init the session
|
||||
if strings.HasPrefix(cfg, "{") {
|
||||
err = json.Unmarshal([]byte(cfg), p)
|
||||
} else {
|
||||
err = p.initOldStyle(cfg)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.connectInit()
|
||||
}
|
||||
|
||||
// for v1.x
|
||||
func (p *Provider) initOldStyle(savePath string) error {
|
||||
address := strings.Split(savePath, ":")
|
||||
p.Host = address[0]
|
||||
|
||||
var err error
|
||||
p.Port, err = strconv.Atoi(address[1])
|
||||
return err
|
||||
}
|
||||
|
||||
// SessionRead return a ssdb client session Store
|
||||
func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
|
||||
if p.client == nil {
|
||||
|
41
server/web/session/ssdb/sess_ssdb_test.go
Normal file
41
server/web/session/ssdb/sess_ssdb_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright 2020
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ssdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProvider_SessionInit(t *testing.T) {
|
||||
// using old style
|
||||
savePath := `localhost:8080`
|
||||
cp := &Provider{}
|
||||
cp.SessionInit(context.Background(), 12, savePath)
|
||||
assert.Equal(t, "localhost", cp.Host)
|
||||
assert.Equal(t, 8080, cp.Port)
|
||||
assert.Equal(t, int64(12), cp.maxLifetime)
|
||||
|
||||
savePath = `
|
||||
{ "host": "localhost", "port": 8080}
|
||||
`
|
||||
cp = &Provider{}
|
||||
cp.SessionInit(context.Background(), 12, savePath)
|
||||
assert.Equal(t, "localhost", cp.Host)
|
||||
assert.Equal(t, 8080, cp.Port)
|
||||
assert.Equal(t, int64(12), cp.maxLifetime)
|
||||
}
|
Loading…
Reference in New Issue
Block a user