From f55bbbdff41b43402fe8f4767ed0ebfcc35da46b Mon Sep 17 00:00:00 2001 From: sidbusy Date: Sat, 5 Sep 2015 10:31:31 +0800 Subject: [PATCH] allows custom the TableName of Session --- session/mysql/sess_mysql.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go index 76a13932..26237e95 100644 --- a/session/mysql/sess_mysql.go +++ b/session/mysql/sess_mysql.go @@ -51,7 +51,10 @@ import ( _ "github.com/go-sql-driver/mysql" ) -var mysqlpder = &MysqlProvider{} +var ( + TableName = "session" + mysqlpder = &MysqlProvider{} +) // mysql session store type MysqlSessionStore struct { @@ -110,7 +113,7 @@ func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) { if err != nil { return } - st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?", + st.c.Exec("UPDATE "+TableName+" set `session_data`=?, `session_expiry`=? where session_key=?", b, time.Now().Unix(), st.sid) } @@ -141,11 +144,11 @@ func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error { // get mysql session by sid func (mp *MysqlProvider) SessionRead(sid string) (session.SessionStore, error) { c := mp.connectInit() - row := c.QueryRow("select session_data from session where session_key=?", sid) + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) if err == sql.ErrNoRows { - c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", + c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix()) } var kv map[interface{}]interface{} @@ -165,7 +168,7 @@ func (mp *MysqlProvider) SessionRead(sid string) (session.SessionStore, error) { func (mp *MysqlProvider) SessionExist(sid string) bool { c := mp.connectInit() defer c.Close() - row := c.QueryRow("select session_data from session where session_key=?", sid) + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) if err == sql.ErrNoRows { @@ -178,13 +181,13 @@ func (mp *MysqlProvider) SessionExist(sid string) bool { // generate new sid for mysql session func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionStore, error) { c := mp.connectInit() - row := c.QueryRow("select session_data from session where session_key=?", oldsid) + row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) var sessiondata []byte err := row.Scan(&sessiondata) if err == sql.ErrNoRows { - c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) + c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) } - c.Exec("update session set `session_key`=? where session_key=?", sid, oldsid) + c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) var kv map[interface{}]interface{} if len(sessiondata) == 0 { kv = make(map[interface{}]interface{}) @@ -201,7 +204,7 @@ func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (session.SessionS // delete mysql session by sid func (mp *MysqlProvider) SessionDestroy(sid string) error { c := mp.connectInit() - c.Exec("DELETE FROM session where session_key=?", sid) + c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) c.Close() return nil } @@ -209,7 +212,7 @@ func (mp *MysqlProvider) SessionDestroy(sid string) error { // delete expired values in mysql session func (mp *MysqlProvider) SessionGC() { c := mp.connectInit() - c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) + c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) c.Close() return } @@ -219,7 +222,7 @@ func (mp *MysqlProvider) SessionAll() int { c := mp.connectInit() defer c.Close() var total int - err := c.QueryRow("SELECT count(*) as num from session").Scan(&total) + err := c.QueryRow("SELECT count(*) as num from " + TableName).Scan(&total) if err != nil { return 0 }