1
0
mirror of https://github.com/astaxie/beego.git synced 2025-06-14 18:40:40 +00:00

move core/session to web/session

This commit is contained in:
Ming Deng
2020-10-05 19:04:22 +08:00
parent 9e6b8fcf34
commit d8e8f41230
41 changed files with 33 additions and 32 deletions

View File

@ -27,7 +27,7 @@ import (
"github.com/astaxie/beego/pkg"
"github.com/astaxie/beego/pkg/core/config"
"github.com/astaxie/beego/pkg/core/logs"
"github.com/astaxie/beego/pkg/core/session"
"github.com/astaxie/beego/pkg/server/web/session"
"github.com/astaxie/beego/pkg/core/utils"
"github.com/astaxie/beego/pkg/server/web/context"

View File

@ -29,7 +29,7 @@ import (
"strings"
"sync"
"github.com/astaxie/beego/pkg/core/session"
"github.com/astaxie/beego/pkg/server/web/session"
)
// Regexes for checking the accept headers

View File

@ -28,7 +28,7 @@ import (
"strconv"
"strings"
"github.com/astaxie/beego/pkg/core/session"
"github.com/astaxie/beego/pkg/server/web/session"
"github.com/astaxie/beego/pkg/server/web/context"
"github.com/astaxie/beego/pkg/server/web/context/param"

View File

@ -8,8 +8,8 @@ import (
"path/filepath"
"github.com/astaxie/beego/pkg/core/logs"
"github.com/astaxie/beego/pkg/core/session"
"github.com/astaxie/beego/pkg/server/web/context"
"github.com/astaxie/beego/pkg/server/web/session"
)
// register MIME type with content type

View File

@ -0,0 +1,114 @@
session
==============
session is a Go session manager. It can use many session providers. Just like the `database/sql` and `database/sql/driver`.
## How to install?
go get github.com/astaxie/beego/session
## What providers are supported?
As of now this session manager support memory, file, Redis and MySQL.
## How to use it?
First you must import it
import (
"github.com/astaxie/beego/session"
)
Then in you web app init the global session manager
var globalSessions *session.Manager
* Use **memory** as provider:
func init() {
globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
go globalSessions.GC()
}
* Use **file** as provider, the last param is the path where you want file to be stored:
func init() {
globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`)
go globalSessions.GC()
}
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
func init() {
globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`)
go globalSessions.GC()
}
* Use **MySQL** as provider, the last param is the DSN, learn more from [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name):
func init() {
globalSessions, _ = session.NewManager(
"mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`)
go globalSessions.GC()
}
* Use **Cookie** as provider:
func init() {
globalSessions, _ = session.NewManager(
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
go globalSessions.GC()
}
Finally in the handlerfunc you can use it like this
func login(w http.ResponseWriter, r *http.Request) {
sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease(w)
username := sess.Get("username")
fmt.Println(username)
if r.Method == "GET" {
t, _ := template.ParseFiles("login.gtpl")
t.Execute(w, nil)
} else {
fmt.Println("username:", r.Form["username"])
sess.Set("username", r.Form["username"])
fmt.Println("password:", r.Form["password"])
}
}
## How to write own provider?
When you develop a web app, maybe you want to write own provider because you must meet the requirements.
Writing a provider is easy. You only need to define two struct types
(Session and Provider), which satisfy the interface definition.
Maybe you will find the **memory** provider is a good example.
type SessionStore interface {
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data
}
type Provider interface {
SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) (bool, error)
SessionRegenerate(oldsid, sid string) (SessionStore, error)
SessionDestroy(sid string) error
SessionAll() int //get all active session
SessionGC()
}
## LICENSE
BSD License http://creativecommons.org/licenses/BSD/

View File

@ -0,0 +1,248 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package couchbase for session provider
//
// depend on github.com/couchbaselabs/go-couchbasee
//
// go install github.com/couchbaselabs/go-couchbase
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/couchbase"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package couchbase
import (
"context"
"net/http"
"strings"
"sync"
couchbase "github.com/couchbase/go-couchbase"
"github.com/astaxie/beego/pkg/server/web/session"
)
var couchbpder = &Provider{}
// SessionStore store each session
type SessionStore struct {
b *couchbase.Bucket
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Provider couchabse provided
type Provider struct {
maxlifetime int64
savePath string
pool string
bucket string
b *couchbase.Bucket
}
// Set value to couchabse session
func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values[key] = value
return nil
}
// Get value from couchabse session
func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
cs.lock.RLock()
defer cs.lock.RUnlock()
if v, ok := cs.values[key]; ok {
return v
}
return nil
}
// Delete value in couchbase session by given key
func (cs *SessionStore) Delete(ctx context.Context, key interface{}) error {
cs.lock.Lock()
defer cs.lock.Unlock()
delete(cs.values, key)
return nil
}
// Flush Clean all values in couchbase session
func (cs *SessionStore) Flush(context.Context) error {
cs.lock.Lock()
defer cs.lock.Unlock()
cs.values = make(map[interface{}]interface{})
return nil
}
// SessionID Get couchbase session store id
func (cs *SessionStore) SessionID(context.Context) string {
return cs.sid
}
// SessionRelease Write couchbase session with Gob string
func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer cs.b.Close()
bo, err := session.EncodeGob(cs.values)
if err != nil {
return
}
cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
}
func (cp *Provider) getBucket() *couchbase.Bucket {
c, err := couchbase.Connect(cp.savePath)
if err != nil {
return nil
}
pool, err := c.GetPool(cp.pool)
if err != nil {
return nil
}
bucket, err := pool.GetBucket(cp.bucket)
if err != nil {
return nil
}
return bucket
}
// SessionInit init couchbase session
// savepath like couchbase server REST/JSON URL
// e.g. http://host:port/, Pool, Bucket
func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
cp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
cp.savePath = configs[0]
}
if len(configs) > 1 {
cp.pool = configs[1]
}
if len(configs) > 2 {
cp.bucket = configs[2]
}
return nil
}
// SessionRead read couchbase session by sid
func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
cp.b = cp.getBucket()
var (
kv map[interface{}]interface{}
err error
doc []byte
)
err = cp.b.Get(sid, &doc)
if err != nil {
return nil, err
} else if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(doc)
if err != nil {
return nil, err
}
}
cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
// SessionExist Check couchbase session exist.
// it checkes sid exist or not.
func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
cp.b = cp.getBucket()
defer cp.b.Close()
var doc []byte
if err := cp.b.Get(sid, &doc); err != nil || doc == nil {
return false, err
}
return true, nil
}
// SessionRegenerate remove oldsid and use sid to generate new session
func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
cp.b = cp.getBucket()
var doc []byte
if err := cp.b.Get(oldsid, &doc); err != nil || doc == nil {
cp.b.Set(sid, int(cp.maxlifetime), "")
} else {
err := cp.b.Delete(oldsid)
if err != nil {
return nil, err
}
_, _ = cp.b.Add(sid, int(cp.maxlifetime), doc)
}
err := cp.b.Get(sid, &doc)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(doc)
if err != nil {
return nil, err
}
}
cs := &SessionStore{b: cp.b, sid: sid, values: kv, maxlifetime: cp.maxlifetime}
return cs, nil
}
// SessionDestroy Remove bucket in this couchbase
func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error {
cp.b = cp.getBucket()
defer cp.b.Close()
cp.b.Delete(sid)
return nil
}
// SessionGC Recycle
func (cp *Provider) SessionGC(context.Context) {
}
// SessionAll return all active session
func (cp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("couchbase", couchbpder)
}

View File

@ -0,0 +1,174 @@
// Package ledis provide session Provider
package ledis
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"github.com/ledisdb/ledisdb/config"
"github.com/ledisdb/ledisdb/ledis"
"github.com/astaxie/beego/pkg/server/web/session"
)
var (
ledispder = &Provider{}
c *ledis.DB
)
// SessionStore ledis session store
type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in ledis session
func (ls *SessionStore) Set(ctx context.Context, key, value interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values[key] = value
return nil
}
// Get value in ledis session
func (ls *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
ls.lock.RLock()
defer ls.lock.RUnlock()
if v, ok := ls.values[key]; ok {
return v
}
return nil
}
// Delete value in ledis session
func (ls *SessionStore) Delete(ctx context.Context, key interface{}) error {
ls.lock.Lock()
defer ls.lock.Unlock()
delete(ls.values, key)
return nil
}
// Flush clear all values in ledis session
func (ls *SessionStore) Flush(context.Context) error {
ls.lock.Lock()
defer ls.lock.Unlock()
ls.values = make(map[interface{}]interface{})
return nil
}
// SessionID get ledis session id
func (ls *SessionStore) SessionID(context.Context) string {
return ls.sid
}
// SessionRelease save session values to ledis
func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(ls.values)
if err != nil {
return
}
c.Set([]byte(ls.sid), b)
c.Expire([]byte(ls.sid), ls.maxlifetime)
}
// Provider ledis session provider
type Provider struct {
maxlifetime int64
savePath string
db int
}
// SessionInit init ledis session
// savepath like ledis server saveDataPath,pool size
// e.g. 127.0.0.1:6379,100,astaxie
func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
var err error
lp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) == 1 {
lp.savePath = configs[0]
} else if len(configs) == 2 {
lp.savePath = configs[0]
lp.db, err = strconv.Atoi(configs[1])
if err != nil {
return err
}
}
cfg := new(config.Config)
cfg.DataDir = lp.savePath
var ledisInstance *ledis.Ledis
ledisInstance, err = ledis.Open(cfg)
if err != nil {
return err
}
c, err = ledisInstance.Select(lp.db)
return err
}
// SessionRead read ledis session by sid
func (lp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var (
kv map[interface{}]interface{}
err error
)
kvs, _ := c.Get([]byte(sid))
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob(kvs); err != nil {
return nil, err
}
}
ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
return ls, nil
}
// SessionExist check ledis session exist by sid
func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
count, _ := c.Exists([]byte(sid))
return count != 0, nil
}
// SessionRegenerate generate new sid for ledis session
func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
count, _ := c.Exists([]byte(sid))
if count == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set([]byte(sid), []byte(""))
c.Expire([]byte(sid), lp.maxlifetime)
} else {
data, _ := c.Get([]byte(oldsid))
c.Set([]byte(sid), data)
c.Expire([]byte(sid), lp.maxlifetime)
}
return lp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete ledis session by id
func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c.Del([]byte(sid))
return nil
}
// SessionGC Impelment method, no used.
func (lp *Provider) SessionGC(context.Context) {
}
// SessionAll return all active session
func (lp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("ledis", ledispder)
}

View File

@ -0,0 +1,231 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package memcache for session provider
//
// depend on github.com/bradfitz/gomemcache/memcache
//
// go install github.com/bradfitz/gomemcache/memcache
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/memcache"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package memcache
import (
"context"
"net/http"
"strings"
"sync"
"github.com/astaxie/beego/pkg/server/web/session"
"github.com/bradfitz/gomemcache/memcache"
)
var mempder = &MemProvider{}
var client *memcache.Client
// SessionStore memcache session store
type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in memcache session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in memcache session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in memcache session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in memcache session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get memcache session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to memcache
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)}
client.Set(&item)
}
// MemProvider memcache session provider
type MemProvider struct {
maxlifetime int64
conninfo []string
poolsize int
password string
}
// SessionInit init memcache session
// savepath like
// e.g. 127.0.0.1:9090
func (rp *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
rp.conninfo = strings.Split(savePath, ";")
client = memcache.New(rp.conninfo...)
return nil
}
// SessionRead read memcache session by sid
func (rp *MemProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
}
}
item, err := client.Get(sid)
if err != nil {
if err == memcache.ErrCacheMiss {
rs := &SessionStore{sid: sid, values: make(map[interface{}]interface{}), maxlifetime: rp.maxlifetime}
return rs, nil
}
return nil, err
}
var kv map[interface{}]interface{}
if len(item.Value) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(item.Value)
if err != nil {
return nil, err
}
}
rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check memcache session exist by sid
func (rp *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return false, err
}
}
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for memcache session
func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
if client == nil {
if err := rp.connectInit(); err != nil {
return nil, err
}
}
var contain []byte
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
item.Key = sid
item.Value = []byte("")
item.Expiration = int32(rp.maxlifetime)
client.Set(item)
} else {
client.Delete(oldsid)
item.Key = sid
item.Expiration = int32(rp.maxlifetime)
client.Set(item)
contain = item.Value
}
var kv map[interface{}]interface{}
if len(contain) == 0 {
kv = make(map[interface{}]interface{})
} else {
var err error
kv, err = session.DecodeGob(contain)
if err != nil {
return nil, err
}
}
rs := &SessionStore{sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionDestroy delete memcache session by id
func (rp *MemProvider) SessionDestroy(ctx context.Context, sid string) error {
if client == nil {
if err := rp.connectInit(); err != nil {
return err
}
}
return client.Delete(sid)
}
func (rp *MemProvider) connectInit() error {
client = memcache.New(rp.conninfo...)
return nil
}
// SessionGC Impelment method, no used.
func (rp *MemProvider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *MemProvider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("memcache", mempder)
}

View File

@ -0,0 +1,235 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package mysql for session provider
//
// depends on github.com/go-sql-driver/mysql:
//
// go install github.com/go-sql-driver/mysql
//
// mysql session support need create table as sql:
// CREATE TABLE `session` (
// `session_key` char(64) NOT NULL,
// `session_data` blob,
// `session_expiry` int(11) unsigned NOT NULL,
// PRIMARY KEY (`session_key`)
// ) ENGINE=MyISAM DEFAULT CHARSET=utf8;
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/mysql"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package mysql
import (
"context"
"database/sql"
"net/http"
"sync"
"time"
"github.com/astaxie/beego/pkg/server/web/session"
// import mysql driver
_ "github.com/go-sql-driver/mysql"
)
var (
// TableName store the session in MySQL
TableName = "session"
mysqlpder = &Provider{}
)
// SessionStore mysql session store
type SessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// Set value in mysql session.
// it is temp value in map.
func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// Get value from mysql session
func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
}
return nil
}
// Delete value in mysql session
func (st *SessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// Flush clear all values in mysql session
func (st *SessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// SessionID get session id of this mysql session store
func (st *SessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease save mysql session values to database.
// must call this method to save values to database.
func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid)
}
// Provider mysql session provider
type Provider struct {
maxlifetime int64
savePath string
}
// connect to mysql
func (mp *Provider) connectInit() *sql.DB {
db, e := sql.Open("mysql", mp.savePath)
if e != nil {
return nil
}
return db
}
// SessionInit init mysql session.
// savepath is the connection string of mysql.
func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
// SessionRead get mysql session by sid
func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
sid, "", time.Now().Unix())
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionExist check mysql session exist
func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for mysql session
func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
}
c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionDestroy delete mysql session by sid
func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM "+TableName+" where session_key=?", sid)
c.Close()
return nil
}
// SessionGC delete expired values in mysql session
func (mp *Provider) SessionGC(context.Context) {
c := mp.connectInit()
c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
c.Close()
}
// SessionAll count values in mysql session
func (mp *Provider) SessionAll(context.Context) int {
c := mp.connectInit()
defer c.Close()
var total int
err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total)
if err != nil {
return 0
}
return total
}
func init() {
session.Register("mysql", mysqlpder)
}

View File

@ -0,0 +1,250 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package postgres for session provider
//
// depends on github.com/lib/pq:
//
// go install github.com/lib/pq
//
//
// needs this table in your database:
//
// CREATE TABLE session (
// session_key char(64) NOT NULL,
// session_data bytea,
// session_expiry timestamp NOT NULL,
// CONSTRAINT session_key PRIMARY KEY(session_key)
// );
//
// will be activated with these settings in app.conf:
//
// SessionOn = true
// SessionProvider = postgresql
// SessionSavePath = "user=a password=b dbname=c sslmode=disable"
// SessionName = session
//
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/postgresql"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package postgres
import (
"context"
"database/sql"
"net/http"
"sync"
"time"
"github.com/astaxie/beego/pkg/server/web/session"
// import postgresql Driver
_ "github.com/lib/pq"
)
var postgresqlpder = &Provider{}
// SessionStore postgresql session store
type SessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// Set value in postgresql session.
// it is temp value in map.
func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// Get value from postgresql session
func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
}
return nil
}
// Delete value in postgresql session
func (st *SessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// Flush clear all values in postgresql session
func (st *SessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// SessionID get session id of this postgresql session store
func (st *SessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease save postgresql session values to database.
// must call this method to save values to database.
func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
defer st.c.Close()
b, err := session.EncodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3",
b, time.Now().Format(time.RFC3339), st.sid)
}
// Provider postgresql session provider
type Provider struct {
maxlifetime int64
savePath string
}
// connect to postgresql
func (mp *Provider) connectInit() *sql.DB {
db, e := sql.Open("postgres", mp.savePath)
if e != nil {
return nil
}
return db
}
// SessionInit init postgresql session.
// savepath is the connection string of postgresql.
func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
// SessionRead get postgresql session by sid
func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
_, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
sid, "", time.Now().Format(time.RFC3339))
if err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionExist check postgresql session exist
func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for postgresql session
func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
oldsid, "", time.Now().Format(time.RFC3339))
}
c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &SessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// SessionDestroy delete postgresql session by sid
func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=$1", sid)
c.Close()
return nil
}
// SessionGC delete expired values in postgresql session
func (mp *Provider) SessionGC(context.Context) {
c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close()
}
// SessionAll count values in postgresql session
func (mp *Provider) SessionAll(context.Context) int {
c := mp.connectInit()
defer c.Close()
var total int
err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
if err != nil {
return 0
}
return total
}
func init() {
session.Register("postgresql", postgresqlpder)
}

View File

@ -0,0 +1,252 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/gomodule/redigo/redis
//
// go install github.com/gomodule/redigo/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package redis
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/pkg/server/web/session"
"github.com/go-redis/redis/v7"
)
var redispder = &Provider{}
// MaxPoolSize redis max pool size
var MaxPoolSize = 100
// SessionStore redis session store
type SessionStore struct {
p *redis.Client
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *redis.Client
}
// SessionInit init redis session
// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second
// e.g. 127.0.0.1:6379,100,astaxie,0,30
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4])
if err == nil && timeout > 0 {
rp.idleTimeout = time.Duration(timeout) * time.Second
}
}
if len(configs) > 5 {
checkFrequency, err := strconv.Atoi(configs[5])
if err == nil && checkFrequency > 0 {
rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second
}
}
if len(configs) > 6 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
}
}
rp.poollist = redis.NewClient(&redis.Options{
Addr: rp.savePath,
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis session by sid
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != redis.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis session exist by sid
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for redis session
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, _ := c.Exists(oldsid).Result(); existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Do(c.Context(), "SET", sid, "", "EX", rp.maxlifetime)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("redis", redispder)
}

View File

@ -0,0 +1,96 @@
package redis
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/astaxie/beego/pkg/server/web/session"
)
func TestRedis(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
}
redisAddr := os.Getenv("REDIS_ADDR")
if redisAddr == "" {
redisAddr = "127.0.0.1:6379"
}
sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr)
globalSession, err := session.NewManager("redis", sessionConfig)
if err != nil {
t.Fatal("could not create manager:", err)
}
go globalSession.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSession.SessionStart(w, r)
if err != nil {
t.Fatal("session start failed:", err)
}
defer sess.SessionRelease(nil, w)
// SET AND GET
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set username failed:", err)
}
username := sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
// DELETE
err = sess.Delete(nil, "username")
if err != nil {
t.Fatal("delete username failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("delete username failed")
}
// FLUSH
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set failed:", err)
}
err = sess.Set(nil, "password", "1qaz2wsx")
if err != nil {
t.Fatal("set failed:", err)
}
username = sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
password := sess.Get(nil, "password")
if password != "1qaz2wsx" {
t.Fatal("get password failed")
}
err = sess.Flush(nil)
if err != nil {
t.Fatal("flush failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("flush failed")
}
password = sess.Get(nil, "password")
if password != nil {
t.Fatal("flush failed")
}
sess.SessionRelease(nil, w)
}

View File

@ -0,0 +1,247 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/go-redis/redis
//
// go install github.com/go-redis/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis_cluster"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package redis_cluster
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/pkg/server/web/session"
rediss "github.com/go-redis/redis/v7"
)
var redispder = &Provider{}
// MaxPoolSize redis_cluster max pool size
var MaxPoolSize = 1000
// SessionStore redis_cluster session store
type SessionStore struct {
p *rediss.ClusterClient
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis_cluster session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis_cluster session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis_cluster session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis_cluster session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis_cluster session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis_cluster
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis_cluster session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *rediss.ClusterClient
}
// SessionInit init redis_cluster session
// savepath like redis server addr,pool size,password,dbnum
// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = MaxPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
timeout, err := strconv.Atoi(configs[4])
if err == nil && timeout > 0 {
rp.idleTimeout = time.Duration(timeout) * time.Second
}
}
if len(configs) > 5 {
checkFrequency, err := strconv.Atoi(configs[5])
if err == nil && checkFrequency > 0 {
rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second
}
}
if len(configs) > 6 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
}
}
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
Addrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_cluster session by sid
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != rediss.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis_cluster session exist by sid
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for redis_cluster session
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("redis_cluster", redispder)
}

View File

@ -0,0 +1,260 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package redis for session provider
//
// depend on github.com/go-redis/redis
//
// go install github.com/go-redis/redis
//
// Usage:
// import(
// _ "github.com/astaxie/beego/session/redis_sentinel"
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``)
// go globalSessions.GC()
// }
//
// more detail about params: please check the notes on the function SessionInit in this package
package redis_sentinel
import (
"context"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/astaxie/beego/pkg/server/web/session"
"github.com/go-redis/redis/v7"
)
var redispder = &Provider{}
// DefaultPoolSize redis_sentinel default pool size
var DefaultPoolSize = 100
// SessionStore redis_sentinel session store
type SessionStore struct {
p *redis.Client
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxlifetime int64
}
// Set value in redis_sentinel session
func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values[key] = value
return nil
}
// Get value in redis_sentinel session
func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
rs.lock.RLock()
defer rs.lock.RUnlock()
if v, ok := rs.values[key]; ok {
return v
}
return nil
}
// Delete value in redis_sentinel session
func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error {
rs.lock.Lock()
defer rs.lock.Unlock()
delete(rs.values, key)
return nil
}
// Flush clear all values in redis_sentinel session
func (rs *SessionStore) Flush(context.Context) error {
rs.lock.Lock()
defer rs.lock.Unlock()
rs.values = make(map[interface{}]interface{})
return nil
}
// SessionID get redis_sentinel session id
func (rs *SessionStore) SessionID(context.Context) string {
return rs.sid
}
// SessionRelease save session values to redis_sentinel
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(rs.values)
if err != nil {
return
}
c := rs.p
c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
// Provider redis_sentinel session provider
type Provider struct {
maxlifetime int64
savePath string
poolsize int
password string
dbNum int
idleTimeout time.Duration
idleCheckFrequency time.Duration
maxRetries int
poollist *redis.Client
masterName string
}
// SessionInit init redis_sentinel session
// savepath like redis sentinel addr,pool size,password,dbnum,masterName
// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster
func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
rp.maxlifetime = maxlifetime
configs := strings.Split(savePath, ",")
if len(configs) > 0 {
rp.savePath = configs[0]
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
if err != nil || poolsize < 0 {
rp.poolsize = DefaultPoolSize
} else {
rp.poolsize = poolsize
}
} else {
rp.poolsize = DefaultPoolSize
}
if len(configs) > 2 {
rp.password = configs[2]
}
if len(configs) > 3 {
dbnum, err := strconv.Atoi(configs[3])
if err != nil || dbnum < 0 {
rp.dbNum = 0
} else {
rp.dbNum = dbnum
}
} else {
rp.dbNum = 0
}
if len(configs) > 4 {
if configs[4] != "" {
rp.masterName = configs[4]
} else {
rp.masterName = "mymaster"
}
} else {
rp.masterName = "mymaster"
}
if len(configs) > 5 {
timeout, err := strconv.Atoi(configs[4])
if err == nil && timeout > 0 {
rp.idleTimeout = time.Duration(timeout) * time.Second
}
}
if len(configs) > 6 {
checkFrequency, err := strconv.Atoi(configs[5])
if err == nil && checkFrequency > 0 {
rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second
}
}
if len(configs) > 7 {
retries, err := strconv.Atoi(configs[6])
if err == nil && retries > 0 {
rp.maxRetries = retries
}
}
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
SentinelAddrs: strings.Split(rp.savePath, ";"),
Password: rp.password,
PoolSize: rp.poolsize,
DB: rp.dbNum,
MasterName: rp.masterName,
IdleTimeout: rp.idleTimeout,
IdleCheckFrequency: rp.idleCheckFrequency,
MaxRetries: rp.maxRetries,
})
return rp.poollist.Ping().Err()
}
// SessionRead read redis_sentinel session by sid
func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
var kv map[interface{}]interface{}
kvs, err := rp.poollist.Get(sid).Result()
if err != nil && err != redis.Nil {
return nil, err
}
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
return rs, nil
}
// SessionExist check redis_sentinel session exist by sid
func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
c := rp.poollist
if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 {
return false, err
}
return true, nil
}
// SessionRegenerate generate new sid for redis_sentinel session
func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
c := rp.poollist
if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 {
// oldsid doesn't exists, set the new sid directly
// ignore error here, since if it return error
// the existed value will be 0
c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second)
} else {
c.Rename(oldsid, sid)
c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second)
}
return rp.SessionRead(context.Background(), sid)
}
// SessionDestroy delete redis session by id
func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error {
c := rp.poollist
c.Del(sid)
return nil
}
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC(context.Context) {
}
// SessionAll return all activeSession
func (rp *Provider) SessionAll(context.Context) int {
return 0
}
func init() {
session.Register("redis_sentinel", redispder)
}

View File

@ -0,0 +1,90 @@
package redis_sentinel
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/astaxie/beego/pkg/server/web/session"
)
func TestRedisSentinel(t *testing.T) {
sessionConfig := &session.ManagerConfig{
CookieName: "gosessionid",
EnableSetCookie: true,
Gclifetime: 3600,
Maxlifetime: 3600,
Secure: false,
CookieLifeTime: 3600,
ProviderConfig: "127.0.0.1:6379,100,,0,master",
}
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
if e != nil {
t.Log(e)
return
}
//todo test if e==nil
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("session start failed:", err)
}
defer sess.SessionRelease(nil, w)
// SET AND GET
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set username failed:", err)
}
username := sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
// DELETE
err = sess.Delete(nil, "username")
if err != nil {
t.Fatal("delete username failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("delete username failed")
}
// FLUSH
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set failed:", err)
}
err = sess.Set(nil, "password", "1qaz2wsx")
if err != nil {
t.Fatal("set failed:", err)
}
username = sess.Get(nil, "username")
if username != "astaxie" {
t.Fatal("get username failed")
}
password := sess.Get(nil, "password")
if password != "1qaz2wsx" {
t.Fatal("get password failed")
}
err = sess.Flush(nil)
if err != nil {
t.Fatal("flush failed:", err)
}
username = sess.Get(nil, "username")
if username != nil {
t.Fatal("flush failed")
}
password = sess.Get(nil, "password")
if password != nil {
t.Fatal("flush failed")
}
sess.SessionRelease(nil, w)
}

View File

@ -0,0 +1,181 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/json"
"net/http"
"net/url"
"sync"
)
var cookiepder = &CookieProvider{}
// CookieSessionStore Cookie SessionStore
type CookieSessionStore struct {
sid string
values map[interface{}]interface{} // session data
lock sync.RWMutex
}
// Set value to cookie session.
// the value are encoded as gob with hash block string.
func (st *CookieSessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// Get value from cookie session
func (st *CookieSessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
}
return nil
}
// Delete value in cookie session
func (st *CookieSessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// Flush Clean all values in cookie session
func (st *CookieSessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// SessionID Return id of this cookie session
func (st *CookieSessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
st.lock.Lock()
encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values)
st.lock.Unlock()
if err == nil {
cookie := &http.Cookie{Name: cookiepder.config.CookieName,
Value: url.QueryEscape(encodedCookie),
Path: "/",
HttpOnly: true,
Secure: cookiepder.config.Secure,
MaxAge: cookiepder.config.Maxage}
http.SetCookie(w, cookie)
}
}
type cookieConfig struct {
SecurityKey string `json:"securityKey"`
BlockKey string `json:"blockKey"`
SecurityName string `json:"securityName"`
CookieName string `json:"cookieName"`
Secure bool `json:"secure"`
Maxage int `json:"maxage"`
}
// CookieProvider Cookie session provider
type CookieProvider struct {
maxlifetime int64
config *cookieConfig
block cipher.Block
}
// SessionInit Init cookie session provider with max lifetime and config json.
// maxlifetime is ignored.
// json config:
// securityKey - hash string
// blockKey - gob encode hash string. it's saved as aes crypto.
// securityName - recognized name in encoded cookie string
// cookieName - cookie name
// maxage - cookie max life time.
func (pder *CookieProvider) SessionInit(ctx context.Context, maxlifetime int64, config string) error {
pder.config = &cookieConfig{}
err := json.Unmarshal([]byte(config), pder.config)
if err != nil {
return err
}
if pder.config.BlockKey == "" {
pder.config.BlockKey = string(generateRandomKey(16))
}
if pder.config.SecurityName == "" {
pder.config.SecurityName = string(generateRandomKey(20))
}
pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
if err != nil {
return err
}
pder.maxlifetime = maxlifetime
return nil
}
// SessionRead Get SessionStore in cooke.
// decode cooke string to map and put into SessionStore with sid.
func (pder *CookieProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
sid, pder.maxlifetime)
if maps == nil {
maps = make(map[interface{}]interface{})
}
rs := &CookieSessionStore{sid: sid, values: maps}
return rs, nil
}
// SessionExist Cookie session is always existed
func (pder *CookieProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
return true, nil
}
// SessionRegenerate Implement method, no used.
func (pder *CookieProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
return nil, nil
}
// SessionDestroy Implement method, no used.
func (pder *CookieProvider) SessionDestroy(ctx context.Context, sid string) error {
return nil
}
// SessionGC Implement method, no used.
func (pder *CookieProvider) SessionGC(context.Context) {
}
// SessionAll Implement method, return 0.
func (pder *CookieProvider) SessionAll(context.Context) int {
return 0
}
// SessionUpdate Implement method, no used.
func (pder *CookieProvider) SessionUpdate(ctx context.Context, sid string) error {
return nil
}
func init() {
Register("cookie", cookiepder)
}

View File

@ -0,0 +1,105 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
conf := new(ManagerConfig)
if err := json.Unmarshal([]byte(config), conf); err != nil {
t.Fatal("json decode error", err)
}
globalSessions, err := NewManager("cookie", conf)
if err != nil {
t.Fatal("init cookie session err", err)
}
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("set error,", err)
}
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get(nil, "username"); username != "astaxie" {
t.Fatal("get username error")
}
sess.SessionRelease(nil, w)
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}
func TestDestorySessionCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
conf := new(ManagerConfig)
if err := json.Unmarshal([]byte(config), conf); err != nil {
t.Fatal("json decode error", err)
}
globalSessions, err := NewManager("cookie", conf)
if err != nil {
t.Fatal("init cookie session err", err)
}
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
session, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("session start err,", err)
}
// request again ,will get same sesssion id .
r1, _ := http.NewRequest("GET", "/", nil)
r1.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
w = httptest.NewRecorder()
newSession, err := globalSessions.SessionStart(w, r1)
if err != nil {
t.Fatal("session start err,", err)
}
if newSession.SessionID(nil) != session.SessionID(nil) {
t.Fatal("get cookie session id is not the same again.")
}
// After destroy session , will get a new session id .
globalSessions.SessionDestroy(w, r1)
r2, _ := http.NewRequest("GET", "/", nil)
r2.Header.Set("Cookie", w.Header().Get("Set-Cookie"))
w = httptest.NewRecorder()
newSession, err = globalSessions.SessionStart(w, r2)
if err != nil {
t.Fatal("session start error")
}
if newSession.SessionID(nil) == session.SessionID(nil) {
t.Fatal("after destroy session and reqeust again ,get cookie session id is same.")
}
}

View File

@ -0,0 +1,316 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
)
var (
filepder = &FileProvider{}
gcmaxlifetime int64
)
// FileSessionStore File session store
type FileSessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// Set value to file session
func (fs *FileSessionStore) Set(ctx context.Context, key, value interface{}) error {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.values[key] = value
return nil
}
// Get value from file session
func (fs *FileSessionStore) Get(ctx context.Context, key interface{}) interface{} {
fs.lock.RLock()
defer fs.lock.RUnlock()
if v, ok := fs.values[key]; ok {
return v
}
return nil
}
// Delete value in file session by given key
func (fs *FileSessionStore) Delete(ctx context.Context, key interface{}) error {
fs.lock.Lock()
defer fs.lock.Unlock()
delete(fs.values, key)
return nil
}
// Flush Clean all values in file session
func (fs *FileSessionStore) Flush(context.Context) error {
fs.lock.Lock()
defer fs.lock.Unlock()
fs.values = make(map[interface{}]interface{})
return nil
}
// SessionID Get file session store id
func (fs *FileSessionStore) SessionID(context.Context) string {
return fs.sid
}
// SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
b, err := EncodeGob(fs.values)
if err != nil {
SLogger.Println(err)
return
}
_, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777)
if err != nil {
SLogger.Println(err)
return
}
} else if os.IsNotExist(err) {
f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
if err != nil {
SLogger.Println(err)
return
}
} else {
return
}
f.Truncate(0)
f.Seek(0, 0)
f.Write(b)
f.Close()
}
// FileProvider File session provider
type FileProvider struct {
lock sync.RWMutex
maxlifetime int64
savePath string
}
// SessionInit Init file session provider.
// savePath sets the session files path.
func (fp *FileProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
fp.maxlifetime = maxlifetime
fp.savePath = savePath
return nil
}
// SessionRead Read file session by sid.
// if file is not exist, create it.
// the file path is generated from sid string.
func (fp *FileProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
invalidChars := "./"
if strings.ContainsAny(sid, invalidChars) {
return nil, errors.New("the sid shouldn't have following characters: " + invalidChars)
}
if len(sid) < 2 {
return nil, errors.New("length of the sid is less than 2")
}
filepder.lock.Lock()
defer filepder.lock.Unlock()
err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755)
if err != nil {
SLogger.Println(err.Error())
}
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777)
} else if os.IsNotExist(err) {
f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
} else {
return nil, err
}
defer f.Close()
os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
var kv map[interface{}]interface{}
b, err := ioutil.ReadAll(f)
if err != nil {
return nil, err
}
if len(b) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = DecodeGob(b)
if err != nil {
return nil, err
}
}
ss := &FileSessionStore{sid: sid, values: kv}
return ss, nil
}
// SessionExist Check file session exist.
// it checks the file named from sid exist or not.
func (fp *FileProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
if len(sid) < 2 {
SLogger.Println("min length of session id is 2 but got length: ", sid)
return false, errors.New("min length of session id is 2")
}
_, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
return err == nil, nil
}
// SessionDestroy Remove all files in this save path
func (fp *FileProvider) SessionDestroy(ctx context.Context, sid string) error {
filepder.lock.Lock()
defer filepder.lock.Unlock()
os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
return nil
}
// SessionGC Recycle files in save path
func (fp *FileProvider) SessionGC(context.Context) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
gcmaxlifetime = fp.maxlifetime
filepath.Walk(fp.savePath, gcpath)
}
// SessionAll Get active file session number.
// it walks save path to count files.
func (fp *FileProvider) SessionAll(context.Context) int {
a := &activeSession{}
err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error {
return a.visit(path, f, err)
})
if err != nil {
SLogger.Printf("filepath.Walk() returned %v\n", err)
return 0
}
return a.total
}
// SessionRegenerate Generate new sid for file session.
// it delete old file and create new file named from new sid.
func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]))
oldSidFile := path.Join(oldPath, oldsid)
newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1]))
newSidFile := path.Join(newPath, sid)
// new sid file is exist
_, err := os.Stat(newSidFile)
if err == nil {
return nil, fmt.Errorf("newsid %s exist", newSidFile)
}
err = os.MkdirAll(newPath, 0755)
if err != nil {
SLogger.Println(err.Error())
}
// if old sid file exist
// 1.read and parse file content
// 2.write content to new sid file
// 3.remove old sid file, change new sid file atime and ctime
// 4.return FileSessionStore
_, err = os.Stat(oldSidFile)
if err == nil {
b, err := ioutil.ReadFile(oldSidFile)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(b) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = DecodeGob(b)
if err != nil {
return nil, err
}
}
ioutil.WriteFile(newSidFile, b, 0777)
os.Remove(oldSidFile)
os.Chtimes(newSidFile, time.Now(), time.Now())
ss := &FileSessionStore{sid: sid, values: kv}
return ss, nil
}
// if old sid file not exist, just create new sid file and return
newf, err := os.Create(newSidFile)
if err != nil {
return nil, err
}
newf.Close()
ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})}
return ss, nil
}
// remove file in save path if expired
func gcpath(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
if (info.ModTime().Unix() + gcmaxlifetime) < time.Now().Unix() {
os.Remove(path)
}
return nil
}
type activeSession struct {
total int
}
func (as *activeSession) visit(paths string, f os.FileInfo, err error) error {
if err != nil {
return err
}
if f.IsDir() {
return nil
}
as.total = as.total + 1
return nil
}
func init() {
Register("file", filepder)
}

View File

@ -0,0 +1,427 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
)
const sid = "Session_id"
const sidNew = "Session_id_new"
const sessionPath = "./_session_runtime"
var (
mutex sync.Mutex
)
func TestFileProvider_SessionInit(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
if fp.maxlifetime != 180 {
t.Error()
}
if fp.savePath != sessionPath {
t.Error()
}
}
func TestFileProvider_SessionExist(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
_, err = fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
}
func TestFileProvider_SessionExist2(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
exists, err = fp.SessionExist(context.Background(), "")
if err == nil {
t.Error()
}
if exists {
t.Error()
}
exists, err = fp.SessionExist(context.Background(), "1")
if err == nil {
t.Error()
}
if exists {
t.Error()
}
}
func TestFileProvider_SessionRead(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
s, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
_ = s.Set(nil, "sessionValue", 18975)
v := s.Get(nil, "sessionValue")
if v.(int) != 18975 {
t.Error()
}
}
func TestFileProvider_SessionRead1(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(context.Background(), "")
if err == nil {
t.Error(err)
}
_, err = fp.SessionRead(context.Background(), "1")
if err == nil {
t.Error(err)
}
}
func TestFileProvider_SessionAll(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 546
for i := 1; i <= sessionCount; i++ {
_, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
}
if fp.SessionAll(nil) != sessionCount {
t.Error()
}
}
func TestFileProvider_SessionRegenerate(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
_, err = fp.SessionRegenerate(context.Background(), sid, sidNew)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
exists, err = fp.SessionExist(context.Background(), sidNew)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
}
func TestFileProvider_SessionDestroy(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
_, err := fp.SessionRead(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if !exists {
t.Error()
}
err = fp.SessionDestroy(context.Background(), sid)
if err != nil {
t.Error(err)
}
exists, err = fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exists {
t.Error()
}
}
func TestFileProvider_SessionGC(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 1, sessionPath)
sessionCount := 412
for i := 1; i <= sessionCount; i++ {
_, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
}
time.Sleep(2 * time.Second)
fp.SessionGC(nil)
if fp.SessionAll(nil) != 0 {
t.Error()
}
}
func TestFileSessionStore_Set(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
err := s.Set(nil, i, i)
if err != nil {
t.Error(err)
}
}
}
func TestFileSessionStore_Get(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
_ = s.Set(nil, i, i)
v := s.Get(nil, i)
if v.(int) != i {
t.Error()
}
}
}
func TestFileSessionStore_Delete(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
s, _ := fp.SessionRead(context.Background(), sid)
s.Set(nil, "1", 1)
if s.Get(nil, "1") == nil {
t.Error()
}
s.Delete(nil, "1")
if s.Get(nil, "1") != nil {
t.Error()
}
}
func TestFileSessionStore_Flush(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 100
s, _ := fp.SessionRead(context.Background(), sid)
for i := 1; i <= sessionCount; i++ {
_ = s.Set(nil, i, i)
}
_ = s.Flush(nil)
for i := 1; i <= sessionCount; i++ {
if s.Get(nil, i) != nil {
t.Error()
}
}
}
func TestFileSessionStore_SessionID(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
sessionCount := 85
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
if s.SessionID(nil) != fmt.Sprintf("%s_%d", sid, i) {
t.Error(err)
}
}
}
func TestFileSessionStore_SessionRelease(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
filepder.savePath = sessionPath
sessionCount := 85
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
s.Set(nil, i, i)
s.SessionRelease(nil, nil)
}
for i := 1; i <= sessionCount; i++ {
s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i))
if err != nil {
t.Error(err)
}
if s.Get(nil, i).(int) != i {
t.Error()
}
}
}

View File

@ -0,0 +1,197 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"container/list"
"context"
"net/http"
"sync"
"time"
)
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
// MemSessionStore memory session store.
// it saved sessions in a map in memory.
type MemSessionStore struct {
sid string //session id
timeAccessed time.Time //last access time
value map[interface{}]interface{} //session store
lock sync.RWMutex
}
// Set value to memory session
func (st *MemSessionStore) Set(ctx context.Context, key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.value[key] = value
return nil
}
// Get value from memory session by key
func (st *MemSessionStore) Get(ctx context.Context, key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.value[key]; ok {
return v
}
return nil
}
// Delete in memory session by key
func (st *MemSessionStore) Delete(ctx context.Context, key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.value, key)
return nil
}
// Flush clear all values in memory session
func (st *MemSessionStore) Flush(context.Context) error {
st.lock.Lock()
defer st.lock.Unlock()
st.value = make(map[interface{}]interface{})
return nil
}
// SessionID get this id of memory session store
func (st *MemSessionStore) SessionID(context.Context) string {
return st.sid
}
// SessionRelease Implement method, no used.
func (st *MemSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
}
// MemProvider Implement the provider interface
type MemProvider struct {
lock sync.RWMutex // locker
sessions map[string]*list.Element // map in memory
list *list.List // for gc
maxlifetime int64
savePath string
}
// SessionInit init memory session
func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error {
pder.maxlifetime = maxlifetime
pder.savePath = savePath
return nil
}
// SessionRead get memory session store by sid
func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) {
pder.lock.RLock()
if element, ok := pder.sessions[sid]; ok {
go pder.SessionUpdate(nil, sid)
pder.lock.RUnlock()
return element.Value.(*MemSessionStore), nil
}
pder.lock.RUnlock()
pder.lock.Lock()
newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
element := pder.list.PushFront(newsess)
pder.sessions[sid] = element
pder.lock.Unlock()
return newsess, nil
}
// SessionExist check session store exist in memory session by sid
func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) {
pder.lock.RLock()
defer pder.lock.RUnlock()
if _, ok := pder.sessions[sid]; ok {
return true, nil
}
return false, nil
}
// SessionRegenerate generate new sid for session store in memory session
func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
pder.lock.RLock()
if element, ok := pder.sessions[oldsid]; ok {
go pder.SessionUpdate(nil, oldsid)
pder.lock.RUnlock()
pder.lock.Lock()
element.Value.(*MemSessionStore).sid = sid
pder.sessions[sid] = element
delete(pder.sessions, oldsid)
pder.lock.Unlock()
return element.Value.(*MemSessionStore), nil
}
pder.lock.RUnlock()
pder.lock.Lock()
newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
element := pder.list.PushFront(newsess)
pder.sessions[sid] = element
pder.lock.Unlock()
return newsess, nil
}
// SessionDestroy delete session store in memory session by id
func (pder *MemProvider) SessionDestroy(ctx context.Context, sid string) error {
pder.lock.Lock()
defer pder.lock.Unlock()
if element, ok := pder.sessions[sid]; ok {
delete(pder.sessions, sid)
pder.list.Remove(element)
return nil
}
return nil
}
// SessionGC clean expired session stores in memory session
func (pder *MemProvider) SessionGC(context.Context) {
pder.lock.RLock()
for {
element := pder.list.Back()
if element == nil {
break
}
if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() {
pder.lock.RUnlock()
pder.lock.Lock()
pder.list.Remove(element)
delete(pder.sessions, element.Value.(*MemSessionStore).sid)
pder.lock.Unlock()
pder.lock.RLock()
} else {
break
}
}
pder.lock.RUnlock()
}
// SessionAll get count number of memory session
func (pder *MemProvider) SessionAll(context.Context) int {
return pder.list.Len()
}
// SessionUpdate expand time of session store by id in memory session
func (pder *MemProvider) SessionUpdate(ctx context.Context, sid string) error {
pder.lock.Lock()
defer pder.lock.Unlock()
if element, ok := pder.sessions[sid]; ok {
element.Value.(*MemSessionStore).timeAccessed = time.Now()
pder.list.MoveToFront(element)
return nil
}
return nil
}
func init() {
Register("memory", mempder)
}

View File

@ -0,0 +1,58 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestMem(t *testing.T) {
config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}`
conf := new(ManagerConfig)
if err := json.Unmarshal([]byte(config), conf); err != nil {
t.Fatal("json decode error", err)
}
globalSessions, _ := NewManager("memory", conf)
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("set error,", err)
}
defer sess.SessionRelease(nil, w)
err = sess.Set(nil, "username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get(nil, "username"); username != "astaxie" {
t.Fatal("get username error")
}
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}

View File

@ -0,0 +1,131 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"crypto/aes"
"encoding/json"
"testing"
)
func Test_gob(t *testing.T) {
a := make(map[interface{}]interface{})
a["username"] = "astaxie"
a[12] = 234
a["user"] = User{"asta", "xie"}
b, err := EncodeGob(a)
if err != nil {
t.Error(err)
}
c, err := DecodeGob(b)
if err != nil {
t.Error(err)
}
if len(c) == 0 {
t.Error("decodeGob empty")
}
if c["username"] != "astaxie" {
t.Error("decode string error")
}
if c[12] != 234 {
t.Error("decode int error")
}
if c["user"].(User).Username != "asta" {
t.Error("decode struct error")
}
}
type User struct {
Username string
NickName string
}
func TestGenerate(t *testing.T) {
str := generateRandomKey(20)
if len(str) != 20 {
t.Fatal("generate length is not equal to 20")
}
}
func TestCookieEncodeDecode(t *testing.T) {
hashKey := "testhashKey"
blockkey := generateRandomKey(16)
block, err := aes.NewCipher(blockkey)
if err != nil {
t.Fatal("NewCipher:", err)
}
securityName := string(generateRandomKey(20))
val := make(map[interface{}]interface{})
val["name"] = "astaxie"
val["gender"] = "male"
str, err := encodeCookie(block, hashKey, securityName, val)
if err != nil {
t.Fatal("encodeCookie:", err)
}
dst, err := decodeCookie(block, hashKey, securityName, str, 3600)
if err != nil {
t.Fatal("decodeCookie", err)
}
if dst["name"] != "astaxie" {
t.Fatal("dst get map error")
}
if dst["gender"] != "male" {
t.Fatal("dst get map error")
}
}
func TestParseConfig(t *testing.T) {
s := `{"cookieName":"gosessionid","gclifetime":3600}`
cf := new(ManagerConfig)
cf.EnableSetCookie = true
err := json.Unmarshal([]byte(s), cf)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
cf2 := new(ManagerConfig)
cf2.EnableSetCookie = true
err = json.Unmarshal([]byte(cc), cf2)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf2.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf2.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
if cf2.EnableSetCookie {
t.Fatal("parseconfig get enableSetCookie error")
}
cconfig := new(cookieConfig)
err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
if err != nil {
t.Fatal("parse ProviderConfig err,", err)
}
if cconfig.CookieName != "gosessionid" {
t.Fatal("ProviderConfig get cookieName error")
}
if cconfig.SecurityKey != "beegocookiehashkey" {
t.Fatal("ProviderConfig get securityKey error")
}
}

View File

@ -0,0 +1,207 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package session
import (
"bytes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"io"
"strconv"
"time"
"github.com/astaxie/beego/pkg/core/utils"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
// EncodeGob encode the obj to gob
func EncodeGob(obj map[interface{}]interface{}) ([]byte, error) {
for _, v := range obj {
gob.Register(v)
}
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
// DecodeGob decode data to map
func DecodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}
// generateRandomKey creates a random key with the given strength.
func generateRandomKey(strength int) []byte {
k := make([]byte, strength)
if n, err := io.ReadFull(rand.Reader, k); n != strength || err != nil {
return utils.RandomCreateBytes(strength)
}
return k
}
// Encryption -----------------------------------------------------------------
// encrypt encrypts a value using the given block in counter mode.
//
// A random initialization vector (http://goo.gl/zF67k) with the length of the
// block size is prepended to the resulting ciphertext.
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
iv := generateRandomKey(block.BlockSize())
if iv == nil {
return nil, errors.New("encrypt: failed to generate random iv")
}
// Encrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
// Return iv + ciphertext.
return append(iv, value...), nil
}
// decrypt decrypts a value using the given block in counter mode.
//
// The value to be decrypted must be prepended by a initialization vector
// (http://goo.gl/zF67k) with the length of the block size.
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
size := block.BlockSize()
if len(value) > size {
// Extract iv.
iv := value[:size]
// Extract ciphertext.
value = value[size:]
// Decrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
return value, nil
}
return nil, errors.New("decrypt: the value could not be decrypted")
}
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
var err error
var b []byte
// 1. EncodeGob.
if b, err = EncodeGob(value); err != nil {
return "", err
}
// 2. Encrypt (optional).
if b, err = encrypt(block, b); err != nil {
return "", err
}
b = encode(b)
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
h := hmac.New(sha256.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
// Append mac, remove name.
b = append(b, sig...)[len(name)+1:]
// 4. Encode to base64.
b = encode(b)
// Done.
return string(b), nil
}
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
// 1. Decode from base64.
b, err := decode([]byte(value))
if err != nil {
return nil, err
}
// 2. Verify MAC. Value is "date|value|mac".
parts := bytes.SplitN(b, []byte("|"), 3)
if len(parts) != 3 {
return nil, errors.New("Decode: invalid value format")
}
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
h := hmac.New(sha256.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
return nil, errors.New("Decode: the value is not valid")
}
// 3. Verify date ranges.
var t1 int64
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
return nil, errors.New("Decode: invalid timestamp")
}
t2 := time.Now().UTC().Unix()
if t1 > t2 {
return nil, errors.New("Decode: timestamp is too new")
}
if t1 < t2-gcmaxlifetime {
return nil, errors.New("Decode: expired timestamp")
}
// 4. Decrypt (optional).
b, err = decode(parts[1])
if err != nil {
return nil, err
}
if b, err = decrypt(block, b); err != nil {
return nil, err
}
// 5. DecodeGob.
dst, err := DecodeGob(b)
if err != nil {
return nil, err
}
return dst, nil
}
// Encoding -------------------------------------------------------------------
// encode encodes a value using base64.
func encode(value []byte) []byte {
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
base64.URLEncoding.Encode(encoded, value)
return encoded
}
// decode decodes a cookie using base64.
func decode(value []byte) ([]byte, error) {
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
b, err := base64.URLEncoding.Decode(decoded, value)
if err != nil {
return nil, err
}
return decoded[:b], nil
}

View File

@ -0,0 +1,384 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package session provider
//
// Usage:
// import(
// "github.com/astaxie/beego/session"
// )
//
// func init() {
// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`)
// go globalSessions.GC()
// }
//
// more docs: http://beego.me/docs/module/session.md
package session
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/textproto"
"net/url"
"os"
"time"
)
// Store contains all data for one session process with specific id.
type Store interface {
Set(ctx context.Context, key, value interface{}) error //set session value
Get(ctx context.Context, key interface{}) interface{} //get session value
Delete(ctx context.Context, key interface{}) error //delete session value
SessionID(ctx context.Context) string //back current sessionID
SessionRelease(ctx context.Context, w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush(ctx context.Context) error //delete all data
}
// Provider contains global session methods and saved SessionStores.
// it can operate a SessionStore by its id.
type Provider interface {
SessionInit(ctx context.Context, gclifetime int64, config string) error
SessionRead(ctx context.Context, sid string) (Store, error)
SessionExist(ctx context.Context, sid string) (bool, error)
SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error)
SessionDestroy(ctx context.Context, sid string) error
SessionAll(ctx context.Context) int //get all active session
SessionGC(ctx context.Context)
}
var provides = make(map[string]Provider)
// SLogger a helpful variable to log information about session
var SLogger = NewSessionLog(os.Stderr)
// Register makes a session provide available by the provided name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
func Register(name string, provide Provider) {
if provide == nil {
panic("session: Register provide is nil")
}
if _, dup := provides[name]; dup {
panic("session: Register called twice for provider " + name)
}
provides[name] = provide
}
//GetProvider
func GetProvider(name string) (Provider, error) {
provider, ok := provides[name]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", name)
}
return provider, nil
}
// ManagerConfig define the session config
type ManagerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"`
DisableHTTPOnly bool `json:"disableHTTPOnly"`
Secure bool `json:"secure"`
CookieLifeTime int `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"`
SessionIDLength int64 `json:"sessionIDLength"`
EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
SessionIDPrefix string `json:"sessionIDPrefix"`
}
// Manager contains Provider and its configuration.
type Manager struct {
provider Provider
config *ManagerConfig
}
// NewManager Create new Manager with provider name and json config string.
// provider name:
// 1. cookie
// 2. file
// 3. memory
// 4. redis
// 5. mysql
// json config:
// 1. is https default false
// 2. hashfunc default sha1
// 3. hashkey default beegosessionkey
// 4. maxage default is none
func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) {
provider, ok := provides[provideName]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
}
if cf.Maxlifetime == 0 {
cf.Maxlifetime = cf.Gclifetime
}
if cf.EnableSidInHTTPHeader {
if cf.SessionNameInHTTPHeader == "" {
panic(errors.New("SessionNameInHTTPHeader is empty"))
}
strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader)
if cf.SessionNameInHTTPHeader != strMimeHeader {
strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader
panic(errors.New(strErrMsg))
}
}
err := provider.SessionInit(nil, cf.Maxlifetime, cf.ProviderConfig)
if err != nil {
return nil, err
}
if cf.SessionIDLength == 0 {
cf.SessionIDLength = 16
}
return &Manager{
provider,
cf,
}, nil
}
// GetProvider return current manager's provider
func (manager *Manager) GetProvider() Provider {
return manager.provider
}
// getSid retrieves session identifier from HTTP Request.
// First try to retrieve id by reading from cookie, session cookie name is configurable,
// if not exist, then retrieve id from querying parameters.
//
// error is not nil when there is anything wrong.
// sid is empty when need to generate a new session id
// otherwise return an valid session id.
func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName)
if errs != nil || cookie.Value == "" {
var sid string
if manager.config.EnableSidInURLQuery {
errs := r.ParseForm()
if errs != nil {
return "", errs
}
sid = r.FormValue(manager.config.CookieName)
}
// if not found in Cookie / param, then read it from request headers
if manager.config.EnableSidInHTTPHeader && sid == "" {
sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader]
if isFound && len(sids) != 0 {
return sids[0], nil
}
}
return sid, nil
}
// HTTP Request contains cookie for sessionid info.
return url.QueryUnescape(cookie.Value)
}
// SessionStart generate or read the session id from http request.
// if session id exists, return SessionStore with this id.
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session Store, err error) {
sid, errs := manager.getSid(r)
if errs != nil {
return nil, errs
}
if sid != "" {
exists, err := manager.provider.SessionExist(nil, sid)
if err != nil {
return nil, err
}
if exists {
return manager.provider.SessionRead(nil, sid)
}
}
// Generate a new session
sid, errs = manager.sessionID()
if errs != nil {
return nil, errs
}
session, err = manager.provider.SessionRead(nil, sid)
if err != nil {
return nil, err
}
cookie := &http.Cookie{
Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: !manager.config.DisableHTTPOnly,
Secure: manager.isSecure(r),
Domain: manager.config.Domain,
}
if manager.config.CookieLifeTime > 0 {
cookie.MaxAge = manager.config.CookieLifeTime
cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
}
r.AddCookie(cookie)
if manager.config.EnableSidInHTTPHeader {
r.Header.Set(manager.config.SessionNameInHTTPHeader, sid)
w.Header().Set(manager.config.SessionNameInHTTPHeader, sid)
}
return
}
// SessionDestroy Destroy session by its id in http request cookie.
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
if manager.config.EnableSidInHTTPHeader {
r.Header.Del(manager.config.SessionNameInHTTPHeader)
w.Header().Del(manager.config.SessionNameInHTTPHeader)
}
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
return
}
sid, _ := url.QueryUnescape(cookie.Value)
manager.provider.SessionDestroy(nil, sid)
if manager.config.EnableSetCookie {
expiration := time.Now()
cookie = &http.Cookie{Name: manager.config.CookieName,
Path: "/",
HttpOnly: !manager.config.DisableHTTPOnly,
Expires: expiration,
MaxAge: -1,
Domain: manager.config.Domain}
http.SetCookie(w, cookie)
}
}
// GetSessionStore Get SessionStore by its id.
func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) {
sessions, err = manager.provider.SessionRead(nil, sid)
return
}
// GC Start session gc process.
// it can do gc in times after gc lifetime.
func (manager *Manager) GC() {
manager.provider.SessionGC(nil)
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
}
// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request.
func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) {
sid, err := manager.sessionID()
if err != nil {
return
}
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
//delete old cookie
session, _ = manager.provider.SessionRead(nil, sid)
cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: !manager.config.DisableHTTPOnly,
Secure: manager.isSecure(r),
Domain: manager.config.Domain,
}
} else {
oldsid, _ := url.QueryUnescape(cookie.Value)
session, _ = manager.provider.SessionRegenerate(nil, oldsid, sid)
cookie.Value = url.QueryEscape(sid)
cookie.HttpOnly = true
cookie.Path = "/"
}
if manager.config.CookieLifeTime > 0 {
cookie.MaxAge = manager.config.CookieLifeTime
cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second)
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
}
r.AddCookie(cookie)
if manager.config.EnableSidInHTTPHeader {
r.Header.Set(manager.config.SessionNameInHTTPHeader, sid)
w.Header().Set(manager.config.SessionNameInHTTPHeader, sid)
}
return
}
// GetActiveSession Get all active sessions count number.
func (manager *Manager) GetActiveSession() int {
return manager.provider.SessionAll(nil)
}
// SetSecure Set cookie with https.
func (manager *Manager) SetSecure(secure bool) {
manager.config.Secure = secure
}
func (manager *Manager) sessionID() (string, error) {
b := make([]byte, manager.config.SessionIDLength)
n, err := rand.Read(b)
if n != len(b) || err != nil {
return "", fmt.Errorf("Could not successfully read from the system CSPRNG")
}
return manager.config.SessionIDPrefix + hex.EncodeToString(b), nil
}
// Set cookie with https.
func (manager *Manager) isSecure(req *http.Request) bool {
if !manager.config.Secure {
return false
}
if req.URL.Scheme != "" {
return req.URL.Scheme == "https"
}
if req.TLS == nil {
return false
}
return true
}
// Log implement the log.Logger
type Log struct {
*log.Logger
}
// NewSessionLog set io.Writer to create a Logger for session.
func NewSessionLog(out io.Writer) *Log {
sl := new(Log)
sl.Logger = log.New(out, "[SESSION]", 1e9)
return sl
}

View File

@ -0,0 +1,201 @@
package ssdb
import (
"context"
"errors"
"net/http"
"strconv"
"strings"
"sync"
"github.com/ssdb/gossdb/ssdb"
"github.com/astaxie/beego/pkg/server/web/session"
)
var ssdbProvider = &Provider{}
// Provider holds ssdb client and configs
type Provider struct {
client *ssdb.Client
host string
port int
maxLifetime int64
}
func (p *Provider) connectInit() error {
var err error
if p.host == "" || p.port == 0 {
return errors.New("SessionInit First")
}
p.client, err = ssdb.Connect(p.host, p.port)
return err
}
// SessionInit init the ssdb with the config
func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, savePath string) error {
p.maxLifetime = maxLifetime
address := strings.Split(savePath, ":")
p.host = address[0]
var err error
if p.port, err = strconv.Atoi(address[1]); err != nil {
return err
}
return p.connectInit()
}
// SessionRead return a ssdb client session Store
func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) {
if p.client == nil {
if err := p.connectInit(); err != nil {
return nil, err
}
}
var kv map[interface{}]interface{}
value, err := p.client.Get(sid)
if err != nil {
return nil, err
}
if value == nil || len(value.(string)) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob([]byte(value.(string)))
if err != nil {
return nil, err
}
}
rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
return rs, nil
}
// SessionExist judged whether sid is exist in session
func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
if p.client == nil {
if err := p.connectInit(); err != nil {
return false, err
}
}
value, err := p.client.Get(sid)
if err != nil {
panic(err)
}
if value == nil || len(value.(string)) == 0 {
return false, nil
}
return true, nil
}
// SessionRegenerate regenerate session with new sid and delete oldsid
func (p *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
// conn.Do("setx", key, v, ttl)
if p.client == nil {
if err := p.connectInit(); err != nil {
return nil, err
}
}
value, err := p.client.Get(oldsid)
if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if value == nil || len(value.(string)) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob([]byte(value.(string)))
if err != nil {
return nil, err
}
_, err = p.client.Del(oldsid)
if err != nil {
return nil, err
}
}
_, e := p.client.Do("setx", sid, value, p.maxLifetime)
if e != nil {
return nil, e
}
rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client}
return rs, nil
}
// SessionDestroy destroy the sid
func (p *Provider) SessionDestroy(ctx context.Context, sid string) error {
if p.client == nil {
if err := p.connectInit(); err != nil {
return err
}
}
_, err := p.client.Del(sid)
return err
}
// SessionGC not implemented
func (p *Provider) SessionGC(context.Context) {
}
// SessionAll not implemented
func (p *Provider) SessionAll(context.Context) int {
return 0
}
// SessionStore holds the session information which stored in ssdb
type SessionStore struct {
sid string
lock sync.RWMutex
values map[interface{}]interface{}
maxLifetime int64
client *ssdb.Client
}
// Set the key and value
func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
s.values[key] = value
return nil
}
// Get return the value by the key
func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
s.lock.Lock()
defer s.lock.Unlock()
if value, ok := s.values[key]; ok {
return value
}
return nil
}
// Delete the key in session store
func (s *SessionStore) Delete(ctx context.Context, key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.values, key)
return nil
}
// Flush delete all keys and values
func (s *SessionStore) Flush(context.Context) error {
s.lock.Lock()
defer s.lock.Unlock()
s.values = make(map[interface{}]interface{})
return nil
}
// SessionID return the sessionID
func (s *SessionStore) SessionID(context.Context) string {
return s.sid
}
// SessionRelease Store the keyvalues into ssdb
func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
b, err := session.EncodeGob(s.values)
if err != nil {
return
}
s.client.Do("setx", s.sid, string(b), s.maxLifetime)
}
func init() {
session.Register("ssdb", ssdbProvider)
}