1
0
mirror of https://github.com/astaxie/beego.git synced 2024-06-29 14:24:13 +00:00

modify session module

change a log
This commit is contained in:
astaxie 2014-01-05 14:48:36 +08:00
parent 95c65de97c
commit 481448fa90
11 changed files with 551 additions and 142 deletions

View File

@ -28,21 +28,21 @@ Then in you web app init the global session manager
* Use **memory** as provider: * Use **memory** as provider:
func init() { func init() {
globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"") globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **file** as provider, the last param is the path where you want file to be stored: * Use **file** as provider, the last param is the path where you want file to be stored:
func init() { func init() {
globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp") globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","./tmp"}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: * Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
func init() { func init() {
globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie") globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","127.0.0.1:6379,100,astaxie"}`)
go globalSessions.GC() go globalSessions.GC()
} }
@ -50,15 +50,24 @@ Then in you web app init the global session manager
func init() { func init() {
globalSessions, _ = session.NewManager( globalSessions, _ = session.NewManager(
"mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value") "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","username:password@protocol(address)/dbname?param=value"}`)
go globalSessions.GC() go globalSessions.GC()
} }
* Use **Cookie** as provider:
func init() {
globalSessions, _ = session.NewManager(
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
go globalSessions.GC()
}
Finally in the handlerfunc you can use it like this Finally in the handlerfunc you can use it like this
func login(w http.ResponseWriter, r *http.Request) { func login(w http.ResponseWriter, r *http.Request) {
sess := globalSessions.SessionStart(w, r) sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease() defer sess.SessionRelease(w)
username := sess.Get("username") username := sess.Get("username")
fmt.Println(username) fmt.Println(username)
if r.Method == "GET" { if r.Method == "GET" {
@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus
Writing a provider is easy. You only need to define two struct types Writing a provider is easy. You only need to define two struct types
(Session and Provider), which satisfy the interface definition. (Session and Provider), which satisfy the interface definition.
Maybe you will find the **memory** provider as good example. Maybe you will find the **memory** provider is a good example.
type SessionStore interface { type SessionStore interface {
Set(key, value interface{}) error //set session value Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID SessionID() string //back current sessionID
SessionRelease() // release the resource & save data to provider SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data Flush() error //delete all data
} }
type Provider interface { type Provider interface {
SessionInit(maxlifetime int64, savePath string) error SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error) SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) bool SessionExist(sid string) bool
SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionRegenerate(oldsid, sid string) (SessionStore, error)

143
session/sess_cookie.go Normal file
View File

@ -0,0 +1,143 @@
package session
import (
"crypto/aes"
"crypto/cipher"
"encoding/json"
"net/http"
"net/url"
"sync"
)
var cookiepder = &CookieProvider{}
type CookieSessionStore struct {
sid string
values map[interface{}]interface{} //session data
lock sync.RWMutex
}
func (st *CookieSessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
func (st *CookieSessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
} else {
return nil
}
return nil
}
func (st *CookieSessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
func (st *CookieSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
func (st *CookieSessionStore) SessionID() string {
return st.sid
}
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
str, err := encodeCookie(cookiepder.block,
cookiepder.config.SecurityKey,
cookiepder.config.SecurityName,
st.values)
if err != nil {
return
}
cookie := &http.Cookie{Name: cookiepder.config.CookieName,
Value: url.QueryEscape(str),
Path: "/",
HttpOnly: true,
Secure: cookiepder.config.Secure}
http.SetCookie(w, cookie)
return
}
type cookieConfig struct {
SecurityKey string `json:"securityKey"`
BlockKey string `json:"blockKey"`
SecurityName string `json:"securityName"`
CookieName string `json:"cookieName"`
Secure bool `json:"secure"`
Maxage int `json:"maxage"`
}
type CookieProvider struct {
maxlifetime int64
config *cookieConfig
block cipher.Block
}
func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error {
pder.config = &cookieConfig{}
err := json.Unmarshal([]byte(config), pder.config)
if err != nil {
return err
}
if pder.config.BlockKey == "" {
pder.config.BlockKey = string(generateRandomKey(16))
}
if pder.config.SecurityName == "" {
pder.config.SecurityName = string(generateRandomKey(20))
}
pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
if err != nil {
return err
}
return nil
}
func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
kv := make(map[interface{}]interface{})
kv, _ = decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
sid, pder.maxlifetime)
rs := &CookieSessionStore{sid: sid, values: kv}
return rs, nil
}
func (pder *CookieProvider) SessionExist(sid string) bool {
return true
}
func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
return nil, nil
}
func (pder *CookieProvider) SessionDestroy(sid string) error {
return nil
}
func (pder *CookieProvider) SessionGC() {
return
}
func (pder *CookieProvider) SessionAll() int {
return 0
}
func (pder *CookieProvider) SessionUpdate(sid string) error {
return nil
}
func init() {
Register("cookie", cookiepder)
}

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
@ -60,7 +61,7 @@ func (fs *FileSessionStore) SessionID() string {
return fs.sid return fs.sid
} }
func (fs *FileSessionStore) SessionRelease() { func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
defer fs.f.Close() defer fs.f.Close()
b, err := encodeGob(fs.values) b, err := encodeGob(fs.values)
if err != nil { if err != nil {

View File

@ -1,38 +0,0 @@
package session
import (
"bytes"
"encoding/gob"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}

View File

@ -2,6 +2,7 @@ package session
import ( import (
"container/list" "container/list"
"net/http"
"sync" "sync"
"time" "time"
) )
@ -9,9 +10,9 @@ import (
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
type MemSessionStore struct { type MemSessionStore struct {
sid string //session id唯一标示 sid string //session id
timeAccessed time.Time //最后访问时间 timeAccessed time.Time //last access time
value map[interface{}]interface{} //session里面存储的值 value map[interface{}]interface{} //session store
lock sync.RWMutex lock sync.RWMutex
} }
@ -51,8 +52,7 @@ func (st *MemSessionStore) SessionID() string {
return st.sid return st.sid
} }
func (st *MemSessionStore) SessionRelease() { func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
} }
type MemProvider struct { type MemProvider struct {

35
session/sess_mem_test.go Normal file
View File

@ -0,0 +1,35 @@
package session
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestMem(t *testing.T) {
globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`)
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease(w)
err := sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
t.Fatal("get username error")
}
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}

View File

@ -9,6 +9,7 @@ package session
import ( import (
"database/sql" "database/sql"
"net/http"
"sync" "sync"
"time" "time"
@ -60,7 +61,7 @@ func (st *MysqlSessionStore) SessionID() string {
return st.sid return st.sid
} }
func (st *MysqlSessionStore) SessionRelease() { func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close() defer st.c.Close()
if len(st.values) > 0 { if len(st.values) > 0 {
b, err := encodeGob(st.values) b, err := encodeGob(st.values)

View File

@ -1,6 +1,7 @@
package session package session
import ( import (
"net/http"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -58,7 +59,7 @@ func (rs *RedisSessionStore) SessionID() string {
return rs.sid return rs.sid
} }
func (rs *RedisSessionStore) SessionRelease() { func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
defer rs.c.Close() defer rs.c.Close()
if len(rs.values) > 0 { if len(rs.values) > 0 {
b, err := encodeGob(rs.values) b, err := encodeGob(rs.values)

View File

@ -1,6 +1,8 @@
package session package session
import ( import (
"crypto/aes"
"encoding/json"
"testing" "testing"
) )
@ -26,3 +28,82 @@ func Test_gob(t *testing.T) {
t.Error("decode int error") t.Error("decode int error")
} }
} }
func TestGenerate(t *testing.T) {
str := generateRandomKey(20)
if len(str) != 20 {
t.Fatal("generate length is not equal to 20")
}
}
func TestCookieEncodeDecode(t *testing.T) {
hashKey := "testhashKey"
blockkey := generateRandomKey(16)
block, err := aes.NewCipher(blockkey)
if err != nil {
t.Fatal("NewCipher:", err)
}
securityName := string(generateRandomKey(20))
val := make(map[interface{}]interface{})
val["name"] = "astaxie"
val["gender"] = "male"
str, err := encodeCookie(block, hashKey, securityName, val)
if err != nil {
t.Fatal("encodeCookie:", err)
}
dst := make(map[interface{}]interface{})
dst, err = decodeCookie(block, hashKey, securityName, str, 3600)
if err != nil {
t.Fatal("decodeCookie", err)
}
if dst["name"] != "astaxie" {
t.Fatal("dst get map error")
}
if dst["gender"] != "male" {
t.Fatal("dst get map error")
}
}
func TestParseConfig(t *testing.T) {
s := `{"cookieName":"gosessionid","gclifetime":3600}`
cf := new(managerConfig)
cf.EnableSetCookie = true
err := json.Unmarshal([]byte(s), cf)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
cf2 := new(managerConfig)
cf2.EnableSetCookie = true
err = json.Unmarshal([]byte(cc), cf2)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf2.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf2.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
if cf2.EnableSetCookie != false {
t.Fatal("parseconfig get enableSetCookie error")
}
cconfig := new(cookieConfig)
err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
if err != nil {
t.Fatal("parse ProviderConfig err,", err)
}
if cconfig.CookieName != "gosessionid" {
t.Fatal("ProviderConfig get cookieName error")
}
if cconfig.SecurityKey != "beegocookiehashkey" {
t.Fatal("ProviderConfig get securityKey error")
}
}

188
session/sess_utils.go Normal file
View File

@ -0,0 +1,188 @@
package session
import (
"bytes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"crypto/subtle"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"io"
"strconv"
"time"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}
// generateRandomKey creates a random key with the given strength.
func generateRandomKey(strength int) []byte {
k := make([]byte, strength)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}
// Encryption -----------------------------------------------------------------
// encrypt encrypts a value using the given block in counter mode.
//
// A random initialization vector (http://goo.gl/zF67k) with the length of the
// block size is prepended to the resulting ciphertext.
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
iv := generateRandomKey(block.BlockSize())
if iv == nil {
return nil, errors.New("encrypt: failed to generate random iv")
}
// Encrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
// Return iv + ciphertext.
return append(iv, value...), nil
}
// decrypt decrypts a value using the given block in counter mode.
//
// The value to be decrypted must be prepended by a initialization vector
// (http://goo.gl/zF67k) with the length of the block size.
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
size := block.BlockSize()
if len(value) > size {
// Extract iv.
iv := value[:size]
// Extract ciphertext.
value = value[size:]
// Decrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
return value, nil
}
return nil, errors.New("decrypt: the value could not be decrypted")
}
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
var err error
var b []byte
// 1. encodeGob.
if b, err = encodeGob(value); err != nil {
return "", err
}
// 2. Encrypt (optional).
if b, err = encrypt(block, b); err != nil {
return "", err
}
b = encode(b)
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
h := hmac.New(sha1.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
// Append mac, remove name.
b = append(b, sig...)[len(name)+1:]
// 4. Encode to base64.
b = encode(b)
// Done.
return string(b), nil
}
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
// 1. Decode from base64.
b, err := decode([]byte(value))
if err != nil {
return nil, err
}
// 2. Verify MAC. Value is "date|value|mac".
parts := bytes.SplitN(b, []byte("|"), 3)
if len(parts) != 3 {
return nil, errors.New("Decode: invalid value %v")
}
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
h := hmac.New(sha1.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
return nil, errors.New("Decode: the value is not valid")
}
// 3. Verify date ranges.
var t1 int64
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
return nil, errors.New("Decode: invalid timestamp")
}
t2 := time.Now().UTC().Unix()
if t1 > t2 {
return nil, errors.New("Decode: timestamp is too new")
}
if t1 < t2-gcmaxlifetime {
return nil, errors.New("Decode: expired timestamp")
}
// 4. Decrypt (optional).
b, err = decode(parts[1])
if err != nil {
return nil, err
}
if b, err = decrypt(block, b); err != nil {
return nil, err
}
// 5. decodeGob.
if dst, err := decodeGob(b); err != nil {
return nil, err
} else {
return dst, nil
}
// Done.
return nil, nil
}
// Encoding -------------------------------------------------------------------
// encode encodes a value using base64.
func encode(value []byte) []byte {
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
base64.URLEncoding.Encode(encoded, value)
return encoded
}
// decode decodes a cookie using base64.
func decode(value []byte) ([]byte, error) {
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
b, err := base64.URLEncoding.Decode(decoded, value)
if err != nil {
return nil, err
}
return decoded[:b], nil
}

View File

@ -6,6 +6,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -18,12 +19,12 @@ type SessionStore interface {
Get(key interface{}) interface{} //get session value Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID SessionID() string //back current sessionID
SessionRelease() // release the resource & save data to provider SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data Flush() error //delete all data
} }
type Provider interface { type Provider interface {
SessionInit(maxlifetime int64, savePath string) error SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error) SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) bool SessionExist(sid string) bool
SessionRegenerate(oldsid, sid string) (SessionStore, error) SessionRegenerate(oldsid, sid string) (SessionStore, error)
@ -47,15 +48,21 @@ func Register(name string, provide Provider) {
provides[name] = provide provides[name] = provide
} }
type managerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxage int `json:"maxage"`
Secure bool `json:"secure"`
SessionIDHashFunc string `json:"sessionIDHashFunc"`
SessionIDHashKey string `json:"sessionIDHashKey"`
CookieLifeTime int64 `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
}
type Manager struct { type Manager struct {
cookieName string //private cookiename
provider Provider provider Provider
maxlifetime int64 config *managerConfig
hashfunc string //support md5 & sha1
hashkey string
maxage int //cookielifetime
secure bool
options []interface{}
} }
//options //options
@ -63,74 +70,49 @@ type Manager struct {
//2. hashfunc default sha1 //2. hashfunc default sha1
//3. hashkey default beegosessionkey //3. hashkey default beegosessionkey
//4. maxage default is none //4. maxage default is none
func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) { func NewManager(provideName, config string) (*Manager, error) {
provider, ok := provides[provideName] provider, ok := provides[provideName]
if !ok { if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
} }
provider.SessionInit(maxlifetime, savePath) cf := new(managerConfig)
secure := false cf.EnableSetCookie = true
if len(options) > 0 { err := json.Unmarshal([]byte(config), cf)
secure = options[0].(bool) if err != nil {
} return nil, err
hashfunc := "sha1"
if len(options) > 1 {
hashfunc = options[1].(string)
}
hashkey := "beegosessionkey"
if len(options) > 2 {
hashkey = options[2].(string)
}
maxage := -1
if len(options) > 3 {
switch options[3].(type) {
case int:
if options[3].(int) > 0 {
maxage = options[3].(int)
} else if options[3].(int) < 0 {
maxage = 0
}
case int64:
if options[3].(int64) > 0 {
maxage = int(options[3].(int64))
} else if options[3].(int64) < 0 {
maxage = 0
}
case int32:
if options[3].(int32) > 0 {
maxage = int(options[3].(int32))
} else if options[3].(int32) < 0 {
maxage = 0
} }
provider.SessionInit(cf.Gclifetime, cf.ProviderConfig)
if cf.SessionIDHashFunc == "" {
cf.SessionIDHashFunc = "sha1"
} }
if cf.SessionIDHashKey == "" {
cf.SessionIDHashKey = string(generateRandomKey(16))
} }
return &Manager{ return &Manager{
provider: provider, provider,
cookieName: cookieName, cf,
maxlifetime: maxlifetime,
hashfunc: hashfunc,
hashkey: hashkey,
maxage: maxage,
secure: secure,
options: options,
}, nil }, nil
} }
//get Session //get Session
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
sid := manager.sessionId(r) sid := manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.secure} Secure: manager.config.Secure}
if manager.maxage >= 0 { if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.maxage cookie.MaxAge = manager.config.Maxage
} }
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
}
r.AddCookie(cookie) r.AddCookie(cookie)
} else { } else {
sid, _ := url.QueryUnescape(cookie.Value) sid, _ := url.QueryUnescape(cookie.Value)
@ -139,15 +121,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
} else { } else {
sid = manager.sessionId(r) sid = manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.secure} Secure: manager.config.Secure}
if manager.maxage >= 0 { if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.maxage cookie.MaxAge = manager.config.Maxage
} }
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
}
r.AddCookie(cookie) r.AddCookie(cookie)
} }
} }
@ -156,13 +140,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
//Destroy sessionid //Destroy sessionid
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" { if err != nil || cookie.Value == "" {
return return
} else { } else {
manager.provider.SessionDestroy(cookie.Value) manager.provider.SessionDestroy(cookie.Value)
expiration := time.Now() expiration := time.Now()
cookie := http.Cookie{Name: manager.cookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1} cookie := http.Cookie{Name: manager.config.CookieName,
Path: "/",
HttpOnly: true,
Expires: expiration,
MaxAge: -1}
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
} }
} }
@ -174,20 +162,20 @@ func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err erro
func (manager *Manager) GC() { func (manager *Manager) GC() {
manager.provider.SessionGC() manager.provider.SessionGC()
time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
} }
func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) { func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
sid := manager.sessionId(r) sid := manager.sessionId(r)
cookie, err := r.Cookie(manager.cookieName) cookie, err := r.Cookie(manager.config.CookieName)
if err != nil && cookie.Value == "" { if err != nil && cookie.Value == "" {
//delete old cookie //delete old cookie
session, _ = manager.provider.SessionRead(sid) session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName, cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid), Value: url.QueryEscape(sid),
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: manager.secure, Secure: manager.config.Secure,
} }
} else { } else {
oldsid, _ := url.QueryUnescape(cookie.Value) oldsid, _ := url.QueryUnescape(cookie.Value)
@ -196,8 +184,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
cookie.HttpOnly = true cookie.HttpOnly = true
cookie.Path = "/" cookie.Path = "/"
} }
if manager.maxage >= 0 { if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.maxage cookie.MaxAge = manager.config.Maxage
} }
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
r.AddCookie(cookie) r.AddCookie(cookie)
@ -209,12 +197,12 @@ func (manager *Manager) GetActiveSession() int {
} }
func (manager *Manager) SetHashFunc(hasfunc, hashkey string) { func (manager *Manager) SetHashFunc(hasfunc, hashkey string) {
manager.hashfunc = hasfunc manager.config.SessionIDHashFunc = hasfunc
manager.hashkey = hashkey manager.config.SessionIDHashKey = hashkey
} }
func (manager *Manager) SetSecure(secure bool) { func (manager *Manager) SetSecure(secure bool) {
manager.secure = secure manager.config.Secure = secure
} }
//remote_addr cruunixnano randdata //remote_addr cruunixnano randdata
@ -224,16 +212,16 @@ func (manager *Manager) sessionId(r *http.Request) (sid string) {
return "" return ""
} }
sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs) sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
if manager.hashfunc == "md5" { if manager.config.SessionIDHashFunc == "md5" {
h := md5.New() h := md5.New()
h.Write([]byte(sig)) h.Write([]byte(sig))
sid = hex.EncodeToString(h.Sum(nil)) sid = hex.EncodeToString(h.Sum(nil))
} else if manager.hashfunc == "sha1" { } else if manager.config.SessionIDHashFunc == "sha1" {
h := hmac.New(sha1.New, []byte(manager.hashkey)) h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
fmt.Fprintf(h, "%s", sig) fmt.Fprintf(h, "%s", sig)
sid = hex.EncodeToString(h.Sum(nil)) sid = hex.EncodeToString(h.Sum(nil))
} else { } else {
h := hmac.New(sha1.New, []byte(manager.hashkey)) h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
fmt.Fprintf(h, "%s", sig) fmt.Fprintf(h, "%s", sig)
sid = hex.EncodeToString(h.Sum(nil)) sid = hex.EncodeToString(h.Sum(nil))
} }