diff --git a/session/sess_file.go b/session/sess_file.go index 132f5a00..50687c9e 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -15,8 +15,7 @@ package session import ( - "errors" - "io" + "fmt" "io/ioutil" "net/http" "os" @@ -135,6 +134,9 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { } else { return nil, err } + + defer f.Close() + os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) var kv map[interface{}]interface{} b, err := ioutil.ReadAll(f) @@ -149,7 +151,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { return nil, err } } - f.Close() + ss := &FileSessionStore{sid: sid, values: kv} return ss, nil } @@ -204,49 +206,58 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { filepder.lock.Lock() defer filepder.lock.Unlock() - err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777) - if err != nil { - SLogger.Println(err.Error()) - } - err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) - if err != nil { - SLogger.Println(err.Error()) - } - _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - var newf *os.File + oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])) + oldSidFile := path.Join(oldPath, oldsid) + newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1])) + newSidFile := path.Join(newPath, sid) + + // new sid file is exist + _, err := os.Stat(newSidFile) if err == nil { - return nil, errors.New("newsid exist") - } else if os.IsNotExist(err) { - newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + return nil, fmt.Errorf("newsid %s exist", newSidFile) } - _, err = os.Stat(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid)) - var f *os.File - if err == nil { - f, err = os.OpenFile(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid), os.O_RDWR, 0777) - io.Copy(newf, f) - } else if os.IsNotExist(err) { - newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - } else { - return nil, err - } - f.Close() - os.Remove(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]))) - os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) - var kv map[interface{}]interface{} - b, err := ioutil.ReadAll(newf) + err = os.MkdirAll(newPath, 0777) if err != nil { - return nil, err + SLogger.Println(err.Error()) } - if len(b) == 0 { - kv = make(map[interface{}]interface{}) - } else { - kv, err = DecodeGob(b) + + // if old sid file exist + // 1.read and parse file content + // 2.write content to new sid file + // 3.remove old sid file, change new sid file atime and ctime + // 4.return FileSessionStore + _, err = os.Stat(oldSidFile) + if err == nil { + b, err := ioutil.ReadFile(oldSidFile) if err != nil { return nil, err } + + var kv map[interface{}]interface{} + if len(b) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = DecodeGob(b) + if err != nil { + return nil, err + } + } + + ioutil.WriteFile(newSidFile, b, 0777) + os.Remove(oldSidFile) + os.Chtimes(newSidFile, time.Now(), time.Now()) + ss := &FileSessionStore{sid: sid, values: kv} + return ss, nil } - ss := &FileSessionStore{sid: sid, values: kv} + + // if old sid file not exist, just create new sid file and return + newf, err := os.Create(newSidFile) + if err != nil { + return nil, err + } + newf.Close() + ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})} return ss, nil }