diff --git a/README.md b/README.md index 42b13e65..f272cdb1 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ func (this *MainController) Get() { } func main() { - beego.RegisterController("/", &MainController{}) + beego.Router("/", &MainController{}) //beego.HttpPort = 8080 // default beego.Run() } @@ -54,16 +54,16 @@ Some associated tools for beego reside in:[bee](https://github.com/astaxie/bee) ============ In beego, a route is a struct paired with a URL-matching pattern. The struct has many method with the same name of http method to serve the http response. Each route is associated with a block. ```go -beego.RegisterController("/", &controllers.MainController{}) -beego.RegisterController("/admin", &admin.UserController{}) -beego.RegisterController("/admin/index", &admin.ArticleController{}) -beego.RegisterController("/admin/addpkg", &admin.AddController{}) +beego.Router("/", &controllers.MainController{}) +beego.Router("/admin", &admin.UserController{}) +beego.Router("/admin/index", &admin.ArticleController{}) +beego.Router("/admin/addpkg", &admin.AddController{}) ``` You can specify custom regular expressions for routes: ```go -beego.RegisterController("/admin/editpkg/:id([0-9]+)", &admin.EditController{}) -beego.RegisterController("/admin/delpkg/:id([0-9]+)", &admin.DelController{}) -beego.RegisterController("/:pkg(.*)", &controllers.MainController{}) +beego.Router("/admin/editpkg/:id([0-9]+)", &admin.EditController{}) +beego.Router("/admin/delpkg/:id([0-9]+)", &admin.DelController{}) +beego.Router("/:pkg(.*)", &controllers.MainController{}) ``` You can also create routes for static files: @@ -87,7 +87,7 @@ beego.Filter(FilterUser) ``` You can also apply filters only when certain REST URL Parameters exist: ```go -beego.RegisterController("/:id([0-9]+)", &admin.EditController{}) +beego.Router("/:id([0-9]+)", &admin.EditController{}) beego.FilterParam("id", func(rw http.ResponseWriter, r *http.Request) { ... }) diff --git a/beego.go b/beego.go index 36158a03..3286cf40 100644 --- a/beego.go +++ b/beego.go @@ -2,8 +2,7 @@ package beego import ( "fmt" - "github.com/astaxie/session" - _ "github.com/astaxie/session/providers/memory" + "github.com/astaxie/beego/session" "html/template" "net" "net/http" @@ -31,9 +30,10 @@ var ( AppConfig *Config //related to session SessionOn bool // wheather auto start session,default is false - SessionProvider string // default session provider memory + SessionProvider string // default session provider memory mysql redis SessionName string // sessionName cookie's name SessionGCMaxLifetime int64 // session's gc maxlifetime + SessionSavePath string // session savepath if use mysql/redis/file this set to the connectinfo UseFcgi bool GlobalSessions *session.Manager //GlobalSessions @@ -60,6 +60,7 @@ func init() { SessionProvider = "memory" SessionName = "beegosessionID" SessionGCMaxLifetime = 3600 + SessionSavePath = "" UseFcgi = false } else { HttpAddr = AppConfig.String("httpaddr") @@ -109,6 +110,11 @@ func init() { } else { SessionName = ar } + if ar := AppConfig.String("sessionsavepath"); ar == "" { + SessionSavePath = "" + } else { + SessionSavePath = ar + } if ar, err := AppConfig.Int("sessiongcmaxlifetime"); err != nil && ar != 0 { int64val, _ := strconv.ParseInt(strconv.Itoa(ar), 10, 64) SessionGCMaxLifetime = int64val @@ -222,7 +228,7 @@ func Run() { BeeApp.Router(`/debug/pprof/:pp([\w]+)`, &ProfController{}) } if SessionOn { - GlobalSessions, _ = session.NewManager(SessionProvider, SessionName, SessionGCMaxLifetime) + GlobalSessions, _ = session.NewManager(SessionProvider, SessionName, SessionGCMaxLifetime, SessionSavePath) go GlobalSessions.GC() } err := BuildTemplate(ViewsPath) diff --git a/controller.go b/controller.go index f2ef3c31..30580160 100644 --- a/controller.go +++ b/controller.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" "encoding/xml" - "github.com/astaxie/session" + "github.com/astaxie/beego/session" "html/template" "io/ioutil" "net/http" @@ -51,7 +51,6 @@ func (c *Controller) Prepare() { } func (c *Controller) Finish() { - } func (c *Controller) Get() { @@ -82,21 +81,6 @@ func (c *Controller) Options() { http.Error(c.Ctx.ResponseWriter, "Method Not Allowed", 405) } -func (c *Controller) SetSession(name string, value interface{}) { - ss := c.StartSession() - ss.Set(name, value) -} - -func (c *Controller) GetSession(name string) interface{} { - ss := c.StartSession() - return ss.Get(name) -} - -func (c *Controller) DelSession(name string) { - ss := c.StartSession() - ss.Delete(name) -} - func (c *Controller) Render() error { rb, err := c.RenderBytes() @@ -190,7 +174,25 @@ func (c *Controller) Input() url.Values { return c.Ctx.Request.Form } -func (c *Controller) StartSession() (sess session.Session) { +func (c *Controller) StartSession() (sess session.SessionStore) { sess = GlobalSessions.SessionStart(c.Ctx.ResponseWriter, c.Ctx.Request) return } + +func (c *Controller) SetSession(name string, value interface{}) { + ss := c.StartSession() + defer ss.SessionRelease() + ss.Set(name, value) +} + +func (c *Controller) GetSession(name string) interface{} { + ss := c.StartSession() + defer ss.SessionRelease() + return ss.Get(name) +} + +func (c *Controller) DelSession(name string) { + ss := c.StartSession() + defer ss.SessionRelease() + ss.Delete(name) +} diff --git a/session/sess_file.go b/session/sess_file.go new file mode 100644 index 00000000..d5ab6b4a --- /dev/null +++ b/session/sess_file.go @@ -0,0 +1,138 @@ +package session + +import ( + "io/ioutil" + "os" + "path" + "path/filepath" + "sync" + "time" +) + +var ( + filepder = &FileProvider{} + gcmaxlifetime int64 +) + +type FileSessionStore struct { + f *os.File + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +func (fs *FileSessionStore) Set(key, value interface{}) error { + fs.lock.Lock() + defer fs.lock.Unlock() + fs.values[key] = value + fs.updatecontent() + return nil +} + +func (fs *FileSessionStore) Get(key interface{}) interface{} { + fs.lock.RLock() + defer fs.lock.RUnlock() + fs.updatecontent() + if v, ok := fs.values[key]; ok { + return v + } else { + return nil + } + return nil +} + +func (fs *FileSessionStore) Delete(key interface{}) error { + fs.lock.Lock() + defer fs.lock.Unlock() + delete(fs.values, key) + fs.updatecontent() + return nil +} + +func (fs *FileSessionStore) SessionID() string { + return fs.sid +} + +func (fs *FileSessionStore) SessionRelease() { + fs.f.Close() +} + +func (fs *FileSessionStore) updatecontent() { + b, err := encodeGob(fs.values) + if err != nil { + return + } + fs.f.Write(b) +} + +type FileProvider struct { + maxlifetime int64 + savePath string +} + +func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { + fp.maxlifetime = maxlifetime + fp.savePath = savePath + return nil +} + +func (fp *FileProvider) SessionRead(sid string) (SessionStore, error) { + err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) + if err != nil { + println(err.Error()) + } + _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + var f *os.File + if err == nil { + f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777) + } else if os.IsNotExist(err) { + f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + } else { + return nil, err + } + os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) + var kv map[interface{}]interface{} + b, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + if len(b) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = decodeGob(b) + if err != nil { + return nil, err + } + } + f.Close() + f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + ss := &FileSessionStore{f: f, sid: sid, values: kv} + return ss, nil +} + +func (fp *FileProvider) SessionDestroy(sid string) error { + os.Remove(path.Join(fp.savePath)) + return nil +} + +func (fp *FileProvider) SessionGC() { + gcmaxlifetime = fp.maxlifetime + filepath.Walk(fp.savePath, gcpath) +} + +func gcpath(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + if (info.ModTime().Unix() + gcmaxlifetime) < time.Now().Unix() { + os.Remove(path) + } + return nil +} + +func init() { + Register("file", filepder) +} diff --git a/session/sess_gob.go b/session/sess_gob.go new file mode 100644 index 00000000..92313947 --- /dev/null +++ b/session/sess_gob.go @@ -0,0 +1,38 @@ +package session + +import ( + "bytes" + "encoding/gob" +) + +func init() { + gob.Register([]interface{}{}) + gob.Register(map[int]interface{}{}) + gob.Register(map[string]interface{}{}) + gob.Register(map[interface{}]interface{}{}) + gob.Register(map[string]string{}) + gob.Register(map[int]string{}) + gob.Register(map[int]int{}) + gob.Register(map[int]int64{}) +} + +func encodeGob(obj map[interface{}]interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(obj) + if err != nil { + return []byte(""), err + } + return buf.Bytes(), nil +} + +func decodeGob(encoded []byte) (map[interface{}]interface{}, error) { + buf := bytes.NewBuffer(encoded) + dec := gob.NewDecoder(buf) + var out map[interface{}]interface{} + err := dec.Decode(&out) + if err != nil { + return nil, err + } + return out, nil +} diff --git a/session/sess_mem.go b/session/sess_mem.go new file mode 100644 index 00000000..c213293d --- /dev/null +++ b/session/sess_mem.go @@ -0,0 +1,128 @@ +package session + +import ( + "container/list" + "sync" + "time" +) + +var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} + +type MemSessionStore struct { + sid string //session id唯一标示 + timeAccessed time.Time //最后访问时间 + value map[interface{}]interface{} //session里面存储的值 + lock sync.RWMutex +} + +func (st *MemSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.value[key] = value + return nil +} + +func (st *MemSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.value[key]; ok { + return v + } else { + return nil + } + return nil +} + +func (st *MemSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.value, key) + return nil +} + +func (st *MemSessionStore) SessionID() string { + return st.sid +} + +func (st *MemSessionStore) SessionRelease() { + +} + +type MemProvider struct { + lock sync.RWMutex //用来锁 + sessions map[string]*list.Element //用来存储在内存 + list *list.List //用来做gc + maxlifetime int64 + savePath string +} + +func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { + pder.maxlifetime = maxlifetime + pder.savePath = savePath + return nil +} + +func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) { + pder.lock.RLock() + if element, ok := pder.sessions[sid]; ok { + go pder.SessionUpdate(sid) + pder.lock.RUnlock() + return element.Value.(*MemSessionStore), nil + } else { + pder.lock.RUnlock() + pder.lock.Lock() + newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})} + element := pder.list.PushBack(newsess) + pder.sessions[sid] = element + pder.lock.Unlock() + return newsess, nil + } + return nil, nil +} + +func (pder *MemProvider) SessionDestroy(sid string) error { + pder.lock.Lock() + defer pder.lock.Unlock() + if element, ok := pder.sessions[sid]; ok { + delete(pder.sessions, sid) + pder.list.Remove(element) + return nil + } + return nil +} + +func (pder *MemProvider) SessionGC() { + pder.lock.RLock() + for { + element := pder.list.Back() + if element == nil { + break + } + if (element.Value.(*MemSessionStore).timeAccessed.Unix() + pder.maxlifetime) < time.Now().Unix() { + pder.lock.RUnlock() + pder.lock.Lock() + pder.list.Remove(element) + delete(pder.sessions, element.Value.(*MemSessionStore).sid) + pder.lock.Unlock() + pder.lock.RLock() + } else { + break + } + } + pder.lock.RUnlock() +} + +func (pder *MemProvider) SessionUpdate(sid string) error { + pder.lock.RLock() + defer pder.lock.RUnlock() + if element, ok := pder.sessions[sid]; ok { + element.Value.(*MemSessionStore).timeAccessed = time.Now() + pder.list.MoveToFront(element) + return nil + } + return nil +} + +func init() { + Register("memory", mempder) +} diff --git a/session/sess_mysql.go b/session/sess_mysql.go new file mode 100644 index 00000000..a55d5612 --- /dev/null +++ b/session/sess_mysql.go @@ -0,0 +1,125 @@ +package session + +//CREATE TABLE `session` ( +// `session_key` char(64) NOT NULL, +// `session_data` blob, +// `session_expiry` int(11) unsigned NOT NULL, +// PRIMARY KEY (`session_key`) +//) ENGINE=MyISAM DEFAULT CHARSET=utf8; + +import ( + "database/sql" + _ "github.com/go-sql-driver/mysql" + "sync" + "time" +) + +var mysqlpder = &MysqlProvider{} + +type MysqlSessionStore struct { + c *sql.DB + sid string + lock sync.RWMutex + values map[interface{}]interface{} +} + +func (st *MysqlSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + st.updatemysql() + return nil +} + +func (st *MysqlSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } else { + return nil + } + return nil +} + +func (st *MysqlSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + st.updatemysql() + return nil +} + +func (st *MysqlSessionStore) SessionID() string { + return st.sid +} + +func (st *MysqlSessionStore) updatemysql() { + b, err := encodeGob(st.values) + if err != nil { + return + } + st.c.Exec("UPDATE session set `session_data`= ? where session_key=?", b, st.sid) +} + +func (st *MysqlSessionStore) SessionRelease() { + st.c.Close() +} + +type MysqlProvider struct { + maxlifetime int64 + savePath string +} + +func (mp *MysqlProvider) connectInit() *sql.DB { + db, e := sql.Open("mysql", mp.savePath) + if e != nil { + return nil + } + return db +} + +func (mp *MysqlProvider) SessionInit(maxlifetime int64, savePath string) error { + mp.maxlifetime = maxlifetime + mp.savePath = savePath + return nil +} + +func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) { + c := mp.connectInit() + row := c.QueryRow("select session_data from session where session_key=?", sid) + var sessiondata []byte + err := row.Scan(&sessiondata) + if err == sql.ErrNoRows { + c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix()) + } + var kv map[interface{}]interface{} + if len(sessiondata) == 0 { + kv = make(map[interface{}]interface{}) + } else { + kv, err = decodeGob(sessiondata) + if err != nil { + return nil, err + } + } + rs := &MysqlSessionStore{c: c, sid: sid, values: kv} + return rs, nil +} + +func (mp *MysqlProvider) SessionDestroy(sid string) error { + c := mp.connectInit() + c.Exec("DELETE FROM session where session_key=?", sid) + c.Close() + return nil +} + +func (mp *MysqlProvider) SessionGC() { + c := mp.connectInit() + c.Exec("DELETE from session where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) + c.Close() + return +} + +func init() { + Register("mysql", mysqlpder) +} diff --git a/session/sess_redis.go b/session/sess_redis.go new file mode 100644 index 00000000..9d3dcd96 --- /dev/null +++ b/session/sess_redis.go @@ -0,0 +1,80 @@ +package session + +import ( + "github.com/garyburd/redigo/redis" +) + +var redispder = &RedisProvider{} + +type RedisSessionStore struct { + c redis.Conn + sid string +} + +func (rs *RedisSessionStore) Set(key, value interface{}) error { + _, err := rs.c.Do("HSET", rs.sid, key, value) + return err +} + +func (rs *RedisSessionStore) Get(key interface{}) interface{} { + v, err := rs.c.Do("GET", rs.sid, key) + if err != nil { + return nil + } + return v +} + +func (rs *RedisSessionStore) Delete(key interface{}) error { + _, err := rs.c.Do("HDEL", rs.sid, key) + return err +} + +func (rs *RedisSessionStore) SessionID() string { + return rs.sid +} + +func (rs *RedisSessionStore) SessionRelease() { + rs.c.Close() +} + +type RedisProvider struct { + maxlifetime int64 + savePath string +} + +func (rp *RedisProvider) connectInit() redis.Conn { + c, err := redis.Dial("tcp", rp.savePath) + if err != nil { + return nil + } + return c +} + +func (rp *RedisProvider) SessionInit(maxlifetime int64, savePath string) error { + rp.maxlifetime = maxlifetime + rp.savePath = savePath + return nil +} + +func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) { + c := rp.connectInit() + if str, err := redis.String(c.Do("GET", sid)); err != nil || str == "" { + c.Do("SET", sid, sid, rp.maxlifetime) + } + rs := &RedisSessionStore{c: c, sid: sid} + return rs, nil +} + +func (rp *RedisProvider) SessionDestroy(sid string) error { + c := rp.connectInit() + c.Do("DEL", sid) + return nil +} + +func (rp *RedisProvider) SessionGC() { + return +} + +func init() { + Register("redis", redispder) +} diff --git a/session/session.go b/session/session.go new file mode 100644 index 00000000..71756139 --- /dev/null +++ b/session/session.go @@ -0,0 +1,97 @@ +package session + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +type SessionStore interface { + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease() // release the resource +} + +type Provider interface { + SessionInit(maxlifetime int64, savePath string) error + SessionRead(sid string) (SessionStore, error) + SessionDestroy(sid string) error + SessionGC() +} + +var provides = make(map[string]Provider) + +// Register makes a session provide available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, provide Provider) { + if provide == nil { + panic("session: Register provide is nil") + } + if _, dup := provides[name]; dup { + panic("session: Register called twice for provider " + name) + } + provides[name] = provide +} + +type Manager struct { + cookieName string //private cookiename + provider Provider + maxlifetime int64 +} + +func NewManager(provideName, cookieName string, maxlifetime int64, savePath string) (*Manager, error) { + provider, ok := provides[provideName] + if !ok { + return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) + } + provider.SessionInit(maxlifetime, savePath) + return &Manager{provider: provider, cookieName: cookieName, maxlifetime: maxlifetime}, nil +} + +//get Session +func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { + cookie, err := r.Cookie(manager.cookieName) + if err != nil || cookie.Value == "" { + sid := manager.sessionId() + session, _ = manager.provider.SessionRead(sid) + cookie := http.Cookie{Name: manager.cookieName, Value: url.QueryEscape(sid), Path: "/", HttpOnly: true, MaxAge: int(manager.maxlifetime)} + http.SetCookie(w, &cookie) + } else { + sid, _ := url.QueryUnescape(cookie.Value) + session, _ = manager.provider.SessionRead(sid) + } + return +} + +//Destroy sessionid +func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(manager.cookieName) + if err != nil || cookie.Value == "" { + return + } else { + manager.provider.SessionDestroy(cookie.Value) + expiration := time.Now() + cookie := http.Cookie{Name: manager.cookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1} + http.SetCookie(w, &cookie) + } +} + +func (manager *Manager) GC() { + manager.provider.SessionGC() + time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) +} + +func (manager *Manager) sessionId() string { + b := make([]byte, 24) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "" + } + return base64.URLEncoding.EncodeToString(b) +}