diff --git a/session/ssdb/sess_ssdb.go b/session/ssdb/sess_ssdb.go index 26dd3e8a..cbb94840 100644 --- a/session/ssdb/sess_ssdb.go +++ b/session/ssdb/sess_ssdb.go @@ -15,7 +15,7 @@ var ssdbProvider = &SsdbProvider{} type SsdbProvider struct { client *ssdb.Client host string - port int32 + port int maxLifetime int64 } @@ -29,15 +29,15 @@ func (p *SsdbProvider) connectInit() error { } func (p *SsdbProvider) SessionInit(maxLifetime int64, savePath string) error { + var e error = nil p.maxLifetime = maxLifetime address := strings.Split(savePath, ":") p.host = address[0] - port, e := strconv.Atoi(address[1]) + p.port, e = strconv.Atoi(address[1]) if e != nil { return e } - p.port = address[1] - err := p.connectinit() + err := p.connectInit() if err != nil { return err } @@ -55,10 +55,10 @@ func (p *SsdbProvider) SessionRead(sid string) (session.Store, error) { if err != nil { return nil, err } - if value == nil || len(value) == 0 { + if value == nil || len(value.(string)) == 0 { kv = make(map[interface{}]interface{}) } else { - kv, err = session.DecodeGob([]byte(value)) + kv, err = session.DecodeGob([]byte(value.(string))) if err != nil { return nil, err } @@ -70,10 +70,14 @@ func (p *SsdbProvider) SessionRead(sid string) (session.Store, error) { func (p *SsdbProvider) SessionExist(sid string) bool { if p.client == nil { if err := p.connectInit(); err != nil { - return nil, err + panic(err) } } - if value == nil || len(value) == 0 { + value, err := p.client.Get(sid) + if err != nil { + panic(err) + } + if value == nil || len(value.(string)) == 0 { return false } return true @@ -87,24 +91,26 @@ func (p *SsdbProvider) SessionRegenerate(oldsid, sid string) (session.Store, err } } value, err := p.client.Get(oldsid) - if err != nil || len(value) == 0 { - value = "" - } else { - err = p.client.Del(sid) - } - _, e := p.client.Do("setx", sid, value, p.maxLifetime) - if e != nil { - return nil, e + if err != nil { + return nil, err } var kv map[interface{}]interface{} - if value == nil || len(value) == 0 { + if value == nil || len(value.(string)) == 0 { kv = make(map[interface{}]interface{}) } else { var err error - kv, err = session.DecodeGob(value) + kv, err = session.DecodeGob([]byte(value.(string))) if err != nil { return nil, err } + _, err = p.client.Del(oldsid) + if err != nil { + return nil, err + } + } + _, e := p.client.Do("setx", sid, value.(string), p.maxLifetime) + if e != nil { + return nil, e } rs := &SessionStore{sid: sid, values: kv, maxLifetime: p.maxLifetime, client: p.client} return rs, nil @@ -113,10 +119,10 @@ func (p *SsdbProvider) SessionRegenerate(oldsid, sid string) (session.Store, err func (p *SsdbProvider) SessionDestroy(sid string) error { if p.client == nil { if err := p.connectInit(); err != nil { - return nil, err + return err } } - flag, err := p.client.Del(sid) + _, err := p.client.Del(sid) if err != nil { return err } @@ -148,7 +154,7 @@ func (s *SessionStore) Set(key, value interface{}) error { func (s *SessionStore) Get(key interface{}) interface{} { s.lock.Lock() defer s.lock.Unlock() - if value, ok := rs.values[key]; ok { + if value, ok := s.values[key]; ok { return value } return nil @@ -175,7 +181,7 @@ func (s *SessionStore) SessionRelease(w http.ResponseWriter) { if err != nil { return } - s.client.Do("setx", s.sid, s.values, s.maxLifetime) + s.client.Do("setx", s.sid, string(b), s.maxLifetime) } func init() { diff --git a/session/ssdb/sess_ssdb_test.go b/session/ssdb/sess_ssdb_test.go index e69de29b..dfeba4a3 100644 --- a/session/ssdb/sess_ssdb_test.go +++ b/session/ssdb/sess_ssdb_test.go @@ -0,0 +1,60 @@ +package ssdb + +import ( + "fmt" + "net/http" + "testing" +) + +func Test(t *testing.T) { + p := &SsdbProvider{} + p.SessionInit(300, "127.0.0.1:8888") + if p.host != "127.0.0.1" || p.port != 8888 { + t.Error("host:port err") + } + if p.client == nil { + t.Error("client err") + } + ss, err := p.SessionRead("1") + if err != nil { + t.Error(err) + } + err = ss.Set("key", "value") + if err != nil { + t.Error(err) + } + if ss.Get("key") != "value" { + t.Error("Get err") + } + err = ss.Delete("key") + //err = ss.Flush() + if err != nil { + t.Error(err) + } + if ss.Get("key") == "value" { + t.Error("Delete/Flush err") + } + if ss.SessionID() != "1" { + t.Error("id err") + } + + ss.Set("key1", "value1") + var w http.ResponseWriter + ss.SessionRelease(w) + new, e := p.SessionRead("1") + if new == nil || e != nil { + t.Error(e) + } + if !p.SessionExist("1") { + t.Error("SessionExist err") + } + newS, er := p.SessionRegenerate("1", "3") + if er != nil || newS == nil { + t.Error("SessionRegenerate err") + } + if p.SessionExist("1") { + t.Error("SessionExist err") + } + fmt.Println(newS) + +}