1
0
mirror of https://github.com/astaxie/beego.git synced 2025-01-22 13:37:12 +00:00

Merge pull request #1 from fuxiaohei/develop

merge develop
This commit is contained in:
傅小黑 2014-01-17 07:48:39 -08:00
commit 6b5108ef92
77 changed files with 3478 additions and 685 deletions

View File

@ -35,9 +35,5 @@ More info [beego.me](http://beego.me)
beego is licensed under the Apache Licence, Version 2.0
(http://www.apache.org/licenses/LICENSE-2.0.html).
## Use case
- Displaying API documentation: [gowalker](https://github.com/Unknwon/gowalker)
- seocms: [seocms](https://github.com/chinakr/seocms)
- CMS: [toropress](https://github.com/insionng/toropress)
[![Clone in Koding](http://learn.koding.com/btn/clone_d.png)][koding]
[koding]: https://koding.com/Teamwork?import=https://github.com/astaxie/beego/archive/master.zip&c=git1

8
app.go
View File

@ -118,6 +118,14 @@ func (app *App) AutoRouter(c ControllerInterface) *App {
return app
}
// AutoRouterWithPrefix adds beego-defined controller handler with prefix.
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
func (app *App) AutoRouterWithPrefix(prefix string, c ControllerInterface) *App {
app.Handlers.AddAutoPrefix(prefix, c)
return app
}
// UrlFor creates a url with another registered controller handler with params.
// The endpoint is formed as path.controller.name to defined the controller method which will run.
// The values need key-pair data to assign into controller method.

126
beego.go
View File

@ -4,6 +4,7 @@ import (
"net/http"
"path"
"path/filepath"
"strconv"
"strings"
"github.com/astaxie/beego/middleware"
@ -13,6 +14,76 @@ import (
// beego web framework version.
const VERSION = "1.0.1"
type hookfunc func() error //hook function to run
var hooks []hookfunc //hook function slice to store the hookfunc
type groupRouter struct {
pattern string
controller ControllerInterface
mappingMethods string
}
// RouterGroups which will store routers
type GroupRouters []groupRouter
// Get a new GroupRouters
func NewGroupRouters() GroupRouters {
return make([]groupRouter, 0)
}
// Add Router in the GroupRouters
// it is for plugin or module to register router
func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
var newRG groupRouter
if len(mappingMethod) > 0 {
newRG = groupRouter{
pattern,
c,
mappingMethod[0],
}
} else {
newRG = groupRouter{
pattern,
c,
"",
}
}
gr = append(gr, newRG)
}
func (gr GroupRouters) AddAuto(c ControllerInterface) {
newRG := groupRouter{
"",
c,
"",
}
gr = append(gr, newRG)
}
// AddGroupRouter with the prefix
// it will register the router in BeeApp
// the follow code is write in modules:
// GR:=NewGroupRouters()
// GR.AddRouter("/login",&UserController,"get:Login")
// GR.AddRouter("/logout",&UserController,"get:Logout")
// GR.AddRouter("/register",&UserController,"get:Reg")
// the follow code is write in app:
// import "github.com/beego/modules/auth"
// AddRouterGroup("/admin", auth.GR)
func AddGroupRouter(prefix string, groups GroupRouters) *App {
for _, v := range groups {
if v.pattern == "" {
BeeApp.AutoRouterWithPrefix(prefix, v.controller)
} else if v.mappingMethods != "" {
BeeApp.Router(prefix+v.pattern, v.controller, v.mappingMethods)
} else {
BeeApp.Router(prefix+v.pattern, v.controller)
}
}
return BeeApp
}
// Router adds a patterned controller handler to BeeApp.
// it's an alias method of App.Router.
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
@ -36,6 +107,13 @@ func AutoRouter(c ControllerInterface) *App {
return BeeApp
}
// AutoPrefix adds controller handler to BeeApp with prefix.
// it's same to App.AutoRouterWithPrefix.
func AutoPrefix(prefix string, c ControllerInterface) *App {
BeeApp.AutoRouterWithPrefix(prefix, c)
return BeeApp
}
// ErrorHandler registers http.HandlerFunc to each http err code string.
// usage:
// beego.ErrorHandler("404",NotFound)
@ -87,6 +165,12 @@ func InsertFilter(pattern string, pos int, filter FilterFunc) *App {
return BeeApp
}
// The hookfunc will run in beego.Run()
// such as sessionInit, middlerware start, buildtemplate, admin start
func AddAPPStartHook(hf hookfunc) {
hooks = append(hooks, hf)
}
// Run beego application.
// it's alias of App.Run.
func Run() {
@ -99,18 +183,32 @@ func Run() {
}
}
//init mime
initMime()
// do hooks function
for _, hk := range hooks {
err := hk()
if err != nil {
panic(err)
}
}
if SessionOn {
GlobalSessions, _ = session.NewManager(SessionProvider,
SessionName,
SessionGCMaxLifetime,
SessionSavePath,
HttpTLS,
SessionHashFunc,
SessionHashKey,
SessionCookieLifeTime)
var err error
sessionConfig := AppConfig.String("sessionConfig")
if sessionConfig == "" {
sessionConfig = `{"cookieName":"` + SessionName + `",` +
`"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
`"providerConfig":"` + SessionSavePath + `",` +
`"secure":` + strconv.FormatBool(HttpTLS) + `,` +
`"sessionIDHashFunc":"` + SessionHashFunc + `",` +
`"sessionIDHashKey":"` + SessionHashKey + `",` +
`"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
`"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
}
GlobalSessions, err = session.NewManager(SessionProvider,
sessionConfig)
if err != nil {
panic(err)
}
go GlobalSessions.GC()
}
@ -123,7 +221,7 @@ func Run() {
middleware.VERSION = VERSION
middleware.AppName = AppName
middleware.RegisterErrorHander()
middleware.RegisterErrorHandler()
if EnableAdmin {
go BeeAdminApp.Run()
@ -131,3 +229,9 @@ func Run() {
BeeApp.Run()
}
func init() {
hooks = make([]hookfunc, 0)
//init mime
AddAPPStartHook(initMime)
}

2
cache/README.md vendored
View File

@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
## Memcache adapter
memory adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
Configure like this:

50
cache/cache_test.go vendored
View File

@ -5,7 +5,7 @@ import (
"time"
)
func Test_cache(t *testing.T) {
func TestCache(t *testing.T) {
bm, err := NewCache("memory", `{"interval":20}`)
if err != nil {
t.Error("init err")
@ -51,3 +51,51 @@ func Test_cache(t *testing.T) {
t.Error("delete err")
}
}
func TestFileCache(t *testing.T) {
bm, err := NewCache("file", `{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`)
if err != nil {
t.Error("init err")
}
if err = bm.Put("astaxie", 1, 10); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
if err = bm.Incr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 2 {
t.Error("get err")
}
if err = bm.Decr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
bm.Delete("astaxie")
if bm.IsExist("astaxie") {
t.Error("delete err")
}
//test string
if err = bm.Put("astaxie", "author", 10); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(string) != "author" {
t.Error("get err")
}
}

10
cache/file.go vendored
View File

@ -61,6 +61,7 @@ func (this *FileCache) StartAndGC(config string) error {
var cfg map[string]string
json.Unmarshal([]byte(config), &cfg)
//fmt.Println(cfg)
//fmt.Println(config)
if _, ok := cfg["CachePath"]; !ok {
cfg["CachePath"] = FileCachePath
}
@ -135,7 +136,7 @@ func (this *FileCache) Get(key string) interface{} {
return ""
}
var to FileCacheItem
Gob_decode([]byte(filedata), &to)
Gob_decode(filedata, &to)
if to.Expired < time.Now().Unix() {
return ""
}
@ -177,7 +178,7 @@ func (this *FileCache) Delete(key string) error {
func (this *FileCache) Incr(key string) error {
data := this.Get(key)
var incr int
fmt.Println(reflect.TypeOf(data).Name())
//fmt.Println(reflect.TypeOf(data).Name())
if reflect.TypeOf(data).Name() != "int" {
incr = 0
} else {
@ -210,8 +211,7 @@ func (this *FileCache) IsExist(key string) bool {
// Clean cached files.
// not implemented.
func (this *FileCache) ClearAll() error {
//this.CachePath .递归删除
//this.CachePath
return nil
}
@ -271,7 +271,7 @@ func Gob_encode(data interface{}) ([]byte, error) {
}
// Gob decodes file cache item.
func Gob_decode(data []byte, to interface{}) error {
func Gob_decode(data []byte, to *FileCacheItem) error {
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
return dec.Decode(&to)

41
cache/memcache.go vendored
View File

@ -21,7 +21,11 @@ func NewMemCache() *MemcacheCache {
// get value from memcache.
func (rc *MemcacheCache) Get(key string) interface{} {
if rc.c == nil {
rc.c = rc.connectInit()
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
v, err := rc.c.Get(key)
if err != nil {
@ -39,7 +43,11 @@ func (rc *MemcacheCache) Get(key string) interface{} {
// put value to memcache. only support string.
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil {
rc.c = rc.connectInit()
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
v, ok := val.(string)
if !ok {
@ -55,7 +63,11 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
// delete value in memcache.
func (rc *MemcacheCache) Delete(key string) error {
if rc.c == nil {
rc.c = rc.connectInit()
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Delete(key)
return err
@ -76,7 +88,11 @@ func (rc *MemcacheCache) Decr(key string) error {
// check value exists in memcache.
func (rc *MemcacheCache) IsExist(key string) bool {
if rc.c == nil {
rc.c = rc.connectInit()
var err error
rc.c, err = rc.connectInit()
if err != nil {
return false
}
}
v, err := rc.c.Get(key)
if err != nil {
@ -93,7 +109,11 @@ func (rc *MemcacheCache) IsExist(key string) bool {
// clear all cached in memcache.
func (rc *MemcacheCache) ClearAll() error {
if rc.c == nil {
rc.c = rc.connectInit()
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
err := rc.c.FlushAll()
return err
@ -109,20 +129,21 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
return errors.New("config has no conn key")
}
rc.conninfo = cf["conn"]
rc.c = rc.connectInit()
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return errors.New("dial tcp conn error")
}
return nil
}
// connect to memcache and keep the connection.
func (rc *MemcacheCache) connectInit() *memcache.Connection {
func (rc *MemcacheCache) connectInit() (*memcache.Connection, error) {
c, err := memcache.Connect(rc.conninfo)
if err != nil {
return nil
return nil, err
}
return c
return c, nil
}
func init() {

118
cache/redis.go vendored
View File

@ -3,6 +3,7 @@ package cache
import (
"encoding/json"
"errors"
"time"
"github.com/beego/redigo/redis"
)
@ -14,7 +15,7 @@ var (
// Redis cache adapter.
type RedisCache struct {
c redis.Conn
p *redis.Pool // redis connection pool
conninfo string
key string
}
@ -24,107 +25,62 @@ func NewRedisCache() *RedisCache {
return &RedisCache{key: DefaultKey}
}
// actually do the redis cmds
func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
c := rc.p.Get()
defer c.Close()
return c.Do(commandName, args...)
}
// Get cache from redis.
func (rc *RedisCache) Get(key string) interface{} {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return nil
}
}
v, err := rc.c.Do("HGET", rc.key, key)
v, err := rc.do("HGET", rc.key, key)
if err != nil {
return nil
}
return v
}
// put cache to redis.
// timeout is ignored.
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("HSET", rc.key, key, val)
_, err := rc.do("HSET", rc.key, key, val)
return err
}
// delete cache in redis.
func (rc *RedisCache) Delete(key string) error {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("HDEL", rc.key, key)
_, err := rc.do("HDEL", rc.key, key)
return err
}
// check cache exist in redis.
func (rc *RedisCache) IsExist(key string) bool {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return false
}
}
v, err := redis.Bool(rc.c.Do("HEXISTS", rc.key, key))
v, err := redis.Bool(rc.do("HEXISTS", rc.key, key))
if err != nil {
return false
}
return v
}
// increase counter in redis.
func (rc *RedisCache) Incr(key string) error {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1))
if err != nil {
return err
}
return nil
_, err := redis.Bool(rc.do("HINCRBY", rc.key, key, 1))
return err
}
// decrease counter in redis.
func (rc *RedisCache) Decr(key string) error {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1))
if err != nil {
return err
}
return nil
_, err := redis.Bool(rc.do("HINCRBY", rc.key, key, -1))
return err
}
// clean all cache in redis. delete this redis collection.
func (rc *RedisCache) ClearAll() error {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("DEL", rc.key)
_, err := rc.do("DEL", rc.key)
return err
}
@ -135,32 +91,42 @@ func (rc *RedisCache) ClearAll() error {
func (rc *RedisCache) StartAndGC(config string) error {
var cf map[string]string
json.Unmarshal([]byte(config), &cf)
if _, ok := cf["key"]; !ok {
cf["key"] = DefaultKey
}
if _, ok := cf["conn"]; !ok {
return errors.New("config has no conn key")
}
rc.key = cf["key"]
rc.conninfo = cf["conn"]
var err error
rc.c, err = rc.connectInit()
if err != nil {
rc.connectInit()
c := rc.p.Get()
defer c.Close()
if err := c.Err(); err != nil {
return err
}
if rc.c == nil {
return errors.New("dial tcp conn error")
}
return nil
}
// connect to redis.
func (rc *RedisCache) connectInit() (redis.Conn, error) {
c, err := redis.Dial("tcp", rc.conninfo)
if err != nil {
return nil, err
func (rc *RedisCache) connectInit() {
// initialize a new pool
rc.p = &redis.Pool{
MaxIdle: 3,
IdleTimeout: 180 * time.Second,
Dial: func() (redis.Conn, error) {
c, err := redis.Dial("tcp", rc.conninfo)
if err != nil {
return nil, err
}
return c, nil
},
}
return c, nil
}
func init() {

View File

@ -40,6 +40,7 @@ var (
SessionHashFunc string // session hash generation func.
SessionHashKey string // session hash salt string.
SessionCookieLifeTime int // the life time of session id in cookie.
SessionAutoSetCookie bool // auto setcookie
UseFcgi bool
MaxMemory int64
EnableGzip bool // flag of enable gzip
@ -96,6 +97,7 @@ func init() {
SessionHashFunc = "sha1"
SessionHashKey = "beegoserversessionkey"
SessionCookieLifeTime = 0 //set cookie default is the brower life
SessionAutoSetCookie = true
UseFcgi = false
@ -139,6 +141,7 @@ func init() {
func ParseConfig() (err error) {
AppConfig, err = config.NewConfig("ini", AppConfigPath)
if err != nil {
AppConfig = config.NewFakeConfig()
return err
} else {
HttpAddr = AppConfig.String("HttpAddr")

View File

@ -6,8 +6,9 @@ import (
// ConfigContainer defines how to get and set value from configuration raw data.
type ConfigContainer interface {
Set(key, val string) error // support section::key type in given key when using ini type.
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Set(key, val string) error // support section::key type in given key when using ini type.
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Strings(key string) []string //get string slice
Int(key string) (int, error)
Int64(key string) (int64, error)
Bool(key string) (bool, error)

62
config/fake.go Normal file
View File

@ -0,0 +1,62 @@
package config
import (
"errors"
"strconv"
"strings"
)
type fakeConfigContainer struct {
data map[string]string
}
func (c *fakeConfigContainer) getData(key string) string {
key = strings.ToLower(key)
return c.data[key]
}
func (c *fakeConfigContainer) Set(key, val string) error {
key = strings.ToLower(key)
c.data[key] = val
return nil
}
func (c *fakeConfigContainer) String(key string) string {
return c.getData(key)
}
func (c *fakeConfigContainer) Strings(key string) []string {
return strings.Split(c.getData(key), ";")
}
func (c *fakeConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.getData(key))
}
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.getData(key), 10, 64)
}
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
return strconv.ParseBool(c.getData(key))
}
func (c *fakeConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.getData(key), 64)
}
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
key = strings.ToLower(key)
if v, ok := c.data[key]; ok {
return v, nil
}
return nil, errors.New("key not find")
}
var _ ConfigContainer = new(fakeConfigContainer)
func NewFakeConfig() ConfigContainer {
return &fakeConfigContainer{
data: make(map[string]string),
}
}

View File

@ -146,6 +146,11 @@ func (c *IniConfigContainer) String(key string) string {
return c.getdata(key)
}
// Strings returns the []string value for a given key.
func (c *IniConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
// if write to one section, the key need be "section::key".
// if the section is not existed, it panics.

View File

@ -19,6 +19,7 @@ copyrequestbody = true
key1="asta"
key2 = "xie"
CaseInsensitive = true
peers = one;two;three
`
func TestIni(t *testing.T) {
@ -78,4 +79,11 @@ func TestIni(t *testing.T) {
if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true {
t.Fatal("get demo.caseinsensitive error")
}
if data := iniconf.Strings("demo::peers"); len(data) != 3 {
t.Fatal("get strings error", data)
} else if data[0] != "one" {
t.Fatal("get first params error not equat to one")
}
}

View File

@ -116,6 +116,11 @@ func (c *JsonConfigContainer) String(key string) string {
return ""
}
// Strings returns the []string value for a given key.
func (c *JsonConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
func (c *JsonConfigContainer) Set(key, val string) error {
c.Lock()

View File

@ -5,6 +5,7 @@ import (
"io/ioutil"
"os"
"strconv"
"strings"
"sync"
"github.com/beego/x2j"
@ -72,6 +73,11 @@ func (c *XMLConfigContainer) String(key string) string {
return ""
}
// Strings returns the []string value for a given key.
func (c *XMLConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
func (c *XMLConfigContainer) Set(key, val string) error {
c.Lock()

View File

@ -7,6 +7,7 @@ import (
"io/ioutil"
"log"
"os"
"strings"
"sync"
"github.com/beego/goyaml2"
@ -117,6 +118,11 @@ func (c *YAMLConfigContainer) String(key string) string {
return ""
}
// Strings returns the []string value for a given key.
func (c *YAMLConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
func (c *YAMLConfigContainer) Set(key, val string) error {
c.Lock()

View File

@ -3,7 +3,6 @@ package beego
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"errors"
@ -22,6 +21,7 @@ import (
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
)
var (
@ -140,7 +140,7 @@ func (c *Controller) RenderString() (string, error) {
return string(b), e
}
// RenderBytes returns the bytes of renderd tempate string. Do not send out response.
// RenderBytes returns the bytes of rendered template string. Do not send out response.
func (c *Controller) RenderBytes() ([]byte, error) {
//if the controller has set layout, then first get the tplname's content set the content to the layout
if c.Layout != "" {
@ -165,7 +165,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
if c.LayoutSections != nil {
for sectionName, sectionTpl := range c.LayoutSections {
if (sectionTpl == "") {
if sectionTpl == "" {
c.Data[sectionName] = ""
continue
}
@ -391,6 +391,7 @@ func (c *Controller) DelSession(name interface{}) {
// SessionRegenerateID regenerates session id for this session.
// the session data have no changes.
func (c *Controller) SessionRegenerateID() {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession
}
@ -454,7 +455,7 @@ func (c *Controller) XsrfToken() string {
} else {
expire = int64(XSRFExpire)
}
token = getRandomString(15)
token = string(utils.RandomCreateBytes(15))
c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire)
}
c._xsrf_token = token
@ -491,14 +492,3 @@ func (c *Controller) XsrfFormHtml() string {
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
return c.controllerName, c.actionName
}
// getRandomString returns random string.
func getRandomString(n int) string {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
var bytes = make([]byte, n)
rand.Read(bytes)
for i, b := range bytes {
bytes[i] = alphanum[b%byte(len(alphanum))]
}
return string(bytes)
}

View File

@ -1,12 +1,13 @@
package controllers
import (
"github.com/astaxie/beego"
"github.com/garyburd/go-websocket/websocket"
"io/ioutil"
"math/rand"
"net/http"
"time"
"github.com/astaxie/beego"
"github.com/gorilla/websocket"
)
const (

View File

@ -28,6 +28,12 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
if router == mr.pattern {
return true, nil
}
//pattern /admin router /admin/ match
//pattern /admin/ router /admin don't match, because url will 301 in router
if n := len(router); n > 1 && router[n-1] == '/' && router[:n-2] == mr.pattern {
return true, nil
}
if mr.hasregex {
if !mr.regex.MatchString(router) {
return false, nil
@ -46,7 +52,7 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
return false, nil
}
func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
func buildFilter(pattern string, filter FilterFunc) (*FilterRouter, error) {
mr := new(FilterRouter)
mr.params = make(map[int]string)
mr.filterFunc = filter
@ -54,7 +60,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
j := 0
for i, part := range parts {
if strings.HasPrefix(part, ":") {
expr := "(.+)"
expr := "(.*)"
//a user may choose to override the default expression
// similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 {
@ -77,7 +83,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
j++
}
if strings.HasPrefix(part, "*") {
expr := "(.+)"
expr := "(.*)"
if part == "*.*" {
mr.params[j] = ":path"
parts[i] = "([^.]+).([^.]+)"
@ -137,12 +143,11 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
pattern = strings.Join(parts, "/")
regex, regexErr := regexp.Compile(pattern)
if regexErr != nil {
//TODO add error handling here to avoid panic
panic(regexErr)
return nil, regexErr
}
mr.regex = regex
mr.hasregex = true
}
mr.pattern = pattern
return mr
return mr, nil
}

View File

@ -23,3 +23,32 @@ func TestFilter(t *testing.T) {
t.Errorf("user define func can't run")
}
}
var FilterAdminUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am admin"))
}
// Filter pattern /admin/:all
// all url like /admin/ /admin/xie will all get filter
func TestPatternTwo(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/ can't run")
}
}
func TestPatternThree(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/astaxie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/astaxie can't run")
}
}

View File

@ -7,6 +7,8 @@ import (
"net"
)
// ConnWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection.
type ConnWriter struct {
lg *log.Logger
innerWriter io.WriteCloser
@ -17,12 +19,15 @@ type ConnWriter struct {
Level int `json:"level"`
}
// create new ConnWrite returning as LoggerInterface.
func NewConn() LoggerInterface {
conn := new(ConnWriter)
conn.Level = LevelTrace
return conn
}
// init connection writer with json config.
// json config only need key "level".
func (c *ConnWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), c)
if err != nil {
@ -31,6 +36,8 @@ func (c *ConnWriter) Init(jsonconfig string) error {
return nil
}
// write message in connection.
// if connection is down, try to re-connect.
func (c *ConnWriter) WriteMsg(msg string, level int) error {
if level < c.Level {
return nil
@ -49,10 +56,12 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
return nil
}
// implementing method. empty.
func (c *ConnWriter) Flush() {
}
// destroy connection writer and close tcp listener.
func (c *ConnWriter) Destroy() {
if c.innerWriter == nil {
return

View File

@ -6,11 +6,13 @@ import (
"os"
)
// ConsoleWriter implements LoggerInterface and writes messages to terminal.
type ConsoleWriter struct {
lg *log.Logger
Level int `json:"level"`
}
// create ConsoleWriter returning as LoggerInterface.
func NewConsole() LoggerInterface {
cw := new(ConsoleWriter)
cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime)
@ -18,6 +20,8 @@ func NewConsole() LoggerInterface {
return cw
}
// init console logger.
// jsonconfig like '{"level":LevelTrace}'.
func (c *ConsoleWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), c)
if err != nil {
@ -26,6 +30,7 @@ func (c *ConsoleWriter) Init(jsonconfig string) error {
return nil
}
// write message in console.
func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
if level < c.Level {
return nil
@ -34,10 +39,12 @@ func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
return nil
}
// implementing method. empty.
func (c *ConsoleWriter) Destroy() {
}
// implementing method. empty.
func (c *ConsoleWriter) Flush() {
}

View File

@ -13,6 +13,8 @@ import (
"time"
)
// FileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency.
type FileLogWriter struct {
*log.Logger
mw *MuxWriter
@ -38,17 +40,20 @@ type FileLogWriter struct {
Level int `json:"level"`
}
// an *os.File writer with locker.
type MuxWriter struct {
sync.Mutex
fd *os.File
}
// write to os.File.
func (l *MuxWriter) Write(b []byte) (int, error) {
l.Lock()
defer l.Unlock()
return l.fd.Write(b)
}
// set os.File in writer.
func (l *MuxWriter) SetFd(fd *os.File) {
if l.fd != nil {
l.fd.Close()
@ -56,6 +61,7 @@ func (l *MuxWriter) SetFd(fd *os.File) {
l.fd = fd
}
// create a FileLogWriter returning as LoggerInterface.
func NewFileWriter() LoggerInterface {
w := &FileLogWriter{
Filename: "",
@ -73,15 +79,16 @@ func NewFileWriter() LoggerInterface {
return w
}
// jsonconfig like this
//{
// Init file logger with json config.
// jsonconfig like:
// {
// "filename":"logs/beego.log",
// "maxlines":10000,
// "maxsize":1<<30,
// "daily":true,
// "maxdays":15,
// "rotate":true
//}
// }
func (w *FileLogWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), w)
if err != nil {
@ -94,6 +101,7 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
return err
}
// start file logger. create log file and set to locker-inside file writer.
func (w *FileLogWriter) StartLogger() error {
fd, err := w.createLogFile()
if err != nil {
@ -122,6 +130,7 @@ func (w *FileLogWriter) docheck(size int) {
w.maxsize_cursize += size
}
// write logger message into file.
func (w *FileLogWriter) WriteMsg(msg string, level int) error {
if level < w.Level {
return nil
@ -158,6 +167,8 @@ func (w *FileLogWriter) initFd() error {
return nil
}
// DoRotate means it need to write file in new file.
// new file name like xx.log.2013-01-01.2
func (w *FileLogWriter) DoRotate() error {
_, err := os.Lstat(w.Filename)
if err == nil { // file exists
@ -211,10 +222,14 @@ func (w *FileLogWriter) deleteOldLog() {
})
}
// destroy file logger, close file writer.
func (w *FileLogWriter) Destroy() {
w.mw.fd.Close()
}
// flush file logger.
// there are no buffering messages in file logger in memory.
// flush file means sync file from disk.
func (w *FileLogWriter) Flush() {
w.mw.fd.Sync()
}

View File

@ -6,6 +6,7 @@ import (
)
const (
// log message levels
LevelTrace = iota
LevelDebug
LevelInfo
@ -16,6 +17,7 @@ const (
type loggerType func() LoggerInterface
// LoggerInterface defines the behavior of a log provider.
type LoggerInterface interface {
Init(config string) error
WriteMsg(msg string, level int) error
@ -38,6 +40,8 @@ func Register(name string, log loggerType) {
adapters[name] = log
}
// BeeLogger is default logger in beego application.
// it can contain several providers and log message into all providers.
type BeeLogger struct {
lock sync.Mutex
level int
@ -50,7 +54,9 @@ type logMsg struct {
msg string
}
// config need to be correct JSON as string: {"interval":360}
// NewLogger returns a new BeeLogger.
// channellen means the number of messages in chan.
// if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channellen int64) *BeeLogger {
bl := new(BeeLogger)
bl.msg = make(chan *logMsg, channellen)
@ -60,6 +66,8 @@ func NewLogger(channellen int64) *BeeLogger {
return bl
}
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
@ -73,6 +81,7 @@ func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
}
}
// remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adaptername string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
@ -96,10 +105,14 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
return nil
}
// set log message level.
// if message level (such as LevelTrace) is less than logger level (such as LevelWarn), ignore message.
func (bl *BeeLogger) SetLevel(l int) {
bl.level = l
}
// start logger chan reading.
// when chan is full, write logs.
func (bl *BeeLogger) StartLogger() {
for {
select {
@ -111,43 +124,50 @@ func (bl *BeeLogger) StartLogger() {
}
}
// log trace level message.
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
msg := fmt.Sprintf("[T] "+format, v...)
bl.writerMsg(LevelTrace, msg)
}
// log debug level message.
func (bl *BeeLogger) Debug(format string, v ...interface{}) {
msg := fmt.Sprintf("[D] "+format, v...)
bl.writerMsg(LevelDebug, msg)
}
// log info level message.
func (bl *BeeLogger) Info(format string, v ...interface{}) {
msg := fmt.Sprintf("[I] "+format, v...)
bl.writerMsg(LevelInfo, msg)
}
// log warn level message.
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
msg := fmt.Sprintf("[W] "+format, v...)
bl.writerMsg(LevelWarn, msg)
}
// log error level message.
func (bl *BeeLogger) Error(format string, v ...interface{}) {
msg := fmt.Sprintf("[E] "+format, v...)
bl.writerMsg(LevelError, msg)
}
// log critical level message.
func (bl *BeeLogger) Critical(format string, v ...interface{}) {
msg := fmt.Sprintf("[C] "+format, v...)
bl.writerMsg(LevelCritical, msg)
}
//flush all chan data
// flush all chan data.
func (bl *BeeLogger) Flush() {
for _, l := range bl.outputs {
l.Flush()
}
}
// close logger, flush all chan data and destroy all adapters in BeeLogger.
func (bl *BeeLogger) Close() {
for {
if len(bl.msg) > 0 {

View File

@ -12,7 +12,7 @@ const (
subjectPhrase = "Diagnostic message from server"
)
// smtpWriter is used to send emails via given SMTP-server.
// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
type SmtpWriter struct {
Username string `json:"Username"`
Password string `json:"password"`
@ -22,10 +22,21 @@ type SmtpWriter struct {
Level int `json:"level"`
}
// create smtp writer.
func NewSmtpWriter() LoggerInterface {
return &SmtpWriter{Level: LevelTrace}
}
// init smtp writer with json config.
// config like:
// {
// "Username":"example@gmail.com",
// "password:"password",
// "host":"smtp.gmail.com:465",
// "subject":"email title",
// "sendTos":["email1","email2"],
// "level":LevelError
// }
func (s *SmtpWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s)
if err != nil {
@ -34,6 +45,8 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
return nil
}
// write message in smtp writer.
// it will send an email with subject and only this message.
func (s *SmtpWriter) WriteMsg(msg string, level int) error {
if level < s.Level {
return nil
@ -65,9 +78,12 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
return err
}
// implementing method. empty.
func (s *SmtpWriter) Flush() {
return
}
// implementing method. empty.
func (s *SmtpWriter) Destroy() {
return
}

View File

@ -5,16 +5,17 @@ import (
"compress/flate"
"compress/gzip"
"errors"
//"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strings"
"sync"
"time"
)
var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo)
var lock sync.RWMutex
// OpenMemZipFile returns MemFile object with a compressed static file.
// it's used for serve static file if gzip enable.
@ -32,12 +33,12 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
modtime := osfileinfo.ModTime()
fileSize := osfileinfo.Size()
lock.RLock()
cfi, ok := gmfim[zip+":"+path]
lock.RUnlock()
if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize {
//fmt.Printf("read %s file %s from cache\n", zip, path)
} else {
//fmt.Printf("NOT read %s file %s from cache\n", zip, path)
var content []byte
if zip == "gzip" {
//将文件内容压缩到zipbuf中
@ -81,8 +82,9 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
}
cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
lock.Lock()
defer lock.Unlock()
gmfim[zip+":"+path] = cfi
//fmt.Printf("%s file %s to %d, cache it\n", zip, path, len(content))
}
return &MemFile{fi: cfi, offset: 0}, nil
}

View File

@ -61,6 +61,7 @@ var tpl = `
</html>
`
// render default application error page with error and stack string.
func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) {
t, _ := template.New("beegoerrortemp").Parse(tpl)
data := make(map[string]string)
@ -71,6 +72,7 @@ func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack str
data["Stack"] = Stack
data["BeegoVersion"] = VERSION
data["GoVersion"] = runtime.Version()
rw.WriteHeader(500)
t.Execute(rw, data)
}
@ -174,18 +176,19 @@ var errtpl = `
</html>
`
// map of http handlers for each error string.
var ErrorMaps map[string]http.HandlerFunc
func init() {
ErrorMaps = make(map[string]http.HandlerFunc)
}
//404
// show 404 notfound error.
func NotFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Page Not Found"
data["Content"] = template.HTML("<br>The Page You have requested flown the coop." +
data["Content"] = template.HTML("<br>The page you have requested has flown the coop." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br>The page has moved" +
@ -198,28 +201,28 @@ func NotFound(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data)
}
//401
// show 401 unauthorized error.
func Unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Unauthorized"
data["Content"] = template.HTML("<br>The Page You have requested can't authorized." +
data["Content"] = template.HTML("<br>The page you have requested can't be authorized." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br>Check the credentials that you supplied" +
"<br>Check the address for errors" +
"<br>The credentials you supplied are incorrect" +
"<br>There are errors in the website address" +
"</ul>")
data["BeegoVersion"] = VERSION
//rw.WriteHeader(http.StatusUnauthorized)
t.Execute(rw, data)
}
//403
// show 403 forbidden error.
func Forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Forbidden"
data["Content"] = template.HTML("<br>The Page You have requested forbidden." +
data["Content"] = template.HTML("<br>The page you have requested is forbidden." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br>Your address may be blocked" +
@ -231,12 +234,12 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data)
}
//503
// show 503 service unavailable error.
func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Service Unavailable"
data["Content"] = template.HTML("<br>The Page You have requested unavailable." +
data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br><br>The page is overloaded" +
@ -247,30 +250,32 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data)
}
//500
// show 500 internal server error.
func InternalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Internal Server Error"
data["Content"] = template.HTML("<br>The Page You have requested has down now." +
data["Content"] = template.HTML("<br>The page you have requested is down right now." +
"<br><br><ul>" +
"<br>simply try again later" +
"<br>you should report the fault to the website administrator" +
"</ul>")
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
data["BeegoVersion"] = VERSION
//rw.WriteHeader(http.StatusInternalServerError)
t.Execute(rw, data)
}
// show 500 internal error with simple text string.
func SimpleServerError(rw http.ResponseWriter, r *http.Request) {
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
// add http handler for given error string.
func Errorhandler(err string, h http.HandlerFunc) {
ErrorMaps[err] = h
}
func RegisterErrorHander() {
// register default error http handlers, 404,401,403,500 and 503.
func RegisterErrorHandler() {
if _, ok := ErrorMaps["404"]; !ok {
ErrorMaps["404"] = NotFound
}
@ -292,6 +297,8 @@ func RegisterErrorHander() {
}
}
// show error string as simple text message.
// if error string is empty, show 500 error as default.
func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) {
if h, ok := ErrorMaps[errcode]; ok {
isint, err := strconv.Atoi(errcode)

View File

@ -2,16 +2,19 @@ package middleware
import "fmt"
// http exceptions
type HTTPException struct {
StatusCode int // http status code 4xx, 5xx
Description string
}
// return http exception error string, e.g. "400 Bad Request".
func (e *HTTPException) Error() string {
// return `status description`, e.g. `400 Bad Request`
return fmt.Sprintf("%d %s", e.StatusCode, e.Description)
}
// map of http exceptions for each http status code int.
// defined 400,401,403,404,405,500,502,503 and 504 default.
var HTTPExceptionMaps map[int]HTTPException
func init() {

View File

@ -544,8 +544,9 @@ var mimemaps map[string]string = map[string]string{
".mustache": "text/html",
}
func initMime() {
func initMime() error {
for k, v := range mimemaps {
mime.AddExtensionType(k, v)
}
return nil
}

View File

@ -16,6 +16,7 @@ var (
commands = make(map[string]commander)
)
// print help.
func printHelp(errs ...string) {
content := `orm command usage:
@ -31,6 +32,7 @@ func printHelp(errs ...string) {
os.Exit(2)
}
// listen for orm command and then run it if command arguments passed.
func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
@ -58,6 +60,7 @@ func RunCommand() {
}
}
// sync database struct command interface.
type commandSyncDb struct {
al *alias
force bool
@ -66,6 +69,7 @@ type commandSyncDb struct {
rtOnError bool
}
// parse orm command line arguments.
func (d *commandSyncDb) Parse(args []string) {
var name string
@ -78,6 +82,7 @@ func (d *commandSyncDb) Parse(args []string) {
d.al = getDbAlias(name)
}
// run orm line command.
func (d *commandSyncDb) Run() error {
var drops []string
if d.force {
@ -208,10 +213,12 @@ func (d *commandSyncDb) Run() error {
return nil
}
// database creation commander interface implement.
type commandSqlAll struct {
al *alias
}
// parse orm command line arguments.
func (d *commandSqlAll) Parse(args []string) {
var name string
@ -222,6 +229,7 @@ func (d *commandSqlAll) Parse(args []string) {
d.al = getDbAlias(name)
}
// run orm line command.
func (d *commandSqlAll) Run() error {
sqls, indexes := getDbCreateSql(d.al)
var all []string
@ -243,6 +251,10 @@ func init() {
commands["sqlall"] = new(commandSqlAll)
}
// run syncdb command line.
// name means table's alias name. default is "default".
// force means run next sql if the current is error.
// verbose means show all info when running command or not.
func RunSyncdb(name string, force bool, verbose bool) error {
BootStrap()

View File

@ -12,6 +12,7 @@ type dbIndex struct {
Sql string
}
// create database drop sql.
func getDbDropSql(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
@ -26,6 +27,7 @@ func getDbDropSql(al *alias) (sqls []string) {
return sqls
}
// get database column type string.
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes()
fieldType := fi.fieldType
@ -79,6 +81,7 @@ checkColumn:
return
}
// create alter sql string.
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi)
@ -90,6 +93,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
}
// create database creation string.
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")

173
orm/db.go
View File

@ -15,7 +15,7 @@ const (
)
var (
ErrMissPK = errors.New("missed pk value")
ErrMissPK = errors.New("missed pk value") // missing pk error
)
var (
@ -45,13 +45,22 @@ var (
}
)
// an instance of dbBaser interface/
type dbBase struct {
ins dbBaser
}
// check dbBase implements dbBaser interface.
var _ dbBaser = new(dbBase)
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
// get struct columns values as interface slice.
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
var columns []string
if names != nil {
columns = *names
}
for _, column := range cols {
var fi *fieldInfo
if fi, _ = mi.fields.GetByAny(column); fi != nil {
@ -64,14 +73,24 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
}
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
if err != nil {
return nil, nil, err
return nil, err
}
columns = append(columns, column)
if names != nil {
columns = append(columns, column)
}
values = append(values, value)
}
if names != nil {
*names = columns
}
return
}
// get one field value in struct column as interface.
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
var value interface{}
if fi.pk {
@ -140,6 +159,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
return value, nil
}
// create insert sql preparation statement object.
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
Q := d.ins.TableQuote()
@ -165,8 +185,9 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
return stmt, query, err
}
// insert struct with prepared statement and given struct reflect value.
func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
_, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil {
return 0, err
}
@ -185,6 +206,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
}
}
// query sql ,read records and persist in dbBaser.
func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
var whereCols []string
var args []interface{}
@ -192,7 +214,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
// if specify cols length > 0, then use it for where condition.
if len(cols) > 0 {
var err error
whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
whereCols = make([]string, 0, len(cols))
args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
if err != nil {
return err
}
@ -202,7 +225,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
if ok == false {
return ErrMissPK
}
whereCols = append(whereCols, pkColumn)
whereCols = []string{pkColumn}
args = append(args, pkValue)
}
@ -243,16 +266,77 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
return nil
}
// execute insert sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
names := make([]string, 0, len(mi.fields.dbcols)-1)
values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil {
return 0, err
}
return d.InsertValue(q, mi, names, values)
return d.InsertValue(q, mi, false, names, values)
}
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) {
// multi-insert sql with given slice struct reflect.Value.
func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
var (
cnt int64
nums int
values []interface{}
names []string
)
// typ := reflect.Indirect(mi.addrField).Type()
length := sind.Len()
for i := 1; i <= length; i++ {
ind := reflect.Indirect(sind.Index(i - 1))
// Is this needed ?
// if !ind.Type().AssignableTo(typ) {
// return cnt, ErrArgs
// }
if i == 1 {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
if err != nil {
return cnt, err
}
values = make([]interface{}, bulk*len(vus))
nums += copy(values, vus)
} else {
vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
if err != nil {
return cnt, err
}
if len(vus) != len(names) {
return cnt, ErrArgs
}
nums += copy(values[nums:], vus)
}
if i > 1 && i%bulk == 0 || length == i {
num, err := d.InsertValue(q, mi, true, names, values[:nums])
if err != nil {
return cnt, err
}
cnt += num
nums = 0
}
}
return cnt, nil
}
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
@ -264,36 +348,51 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values
qmarks := strings.Join(marks, ", ")
columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query)
if d.ins.HasReturningID(mi, &query) {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
} else {
if isMulti || !d.ins.HasReturningID(mi, &query) {
if res, err := q.Exec(query, values...); err == nil {
if isMulti {
return res.RowsAffected()
}
return res.LastInsertId()
} else {
return 0, err
}
} else {
row := q.QueryRow(query, values...)
var id int64
err := row.Scan(&id)
return id, err
}
}
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
var setNames []string
// if specify cols length is zero, then commit all columns.
if len(cols) == 0 {
cols = mi.fields.dbcols
setNames = make([]string, 0, len(mi.fields.dbcols)-1)
} else {
setNames = make([]string, 0, len(cols))
}
setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
if err != nil {
return 0, err
}
@ -317,6 +416,8 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
return 0, nil
}
// execute delete sql dbQuerier with given struct reflect.Value.
// delete index is pk.
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
if ok == false {
@ -358,6 +459,8 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
return 0, nil
}
// update table-related record by querySet.
// need querySet not struct reflect.Value to update related records.
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) {
columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params))
@ -433,6 +536,8 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil
}
// delete related records.
// do UpdateBanch or DeleteBanch by condition of tables' relationship.
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error {
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
@ -459,8 +564,11 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *
return nil
}
// delete table-related records.
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) {
tables := newDbTables(mi, d.ins)
tables.skipEnd = true
if qs != nil {
tables.parseRelated(qs.related, qs.relDepth)
}
@ -486,6 +594,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
rs = r
}
defer rs.Close()
var ref interface{}
args = make([]interface{}, 0)
@ -532,6 +642,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
return 0, nil
}
// read related records.
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
val := reflect.ValueOf(container)
@ -640,6 +751,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
refs[i] = &ref
}
defer rs.Close()
slice := ind
var cnt int64
@ -739,6 +852,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
return cnt, nil
}
// excute count sql and return count result int64.
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) {
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
@ -759,6 +873,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
return
}
// generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
sql := ""
params := getFlatParams(fi, args, tz)
@ -812,10 +927,12 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri
return sql, params
}
// gernerate sql string with inner function, such as UPPER(text).
func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) {
// default not use
}
// set values to struct column.
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) {
for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
@ -837,6 +954,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
}
}
// convert value from database result to value following in field type.
func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) {
if val == nil {
return nil, nil
@ -989,6 +1107,7 @@ end:
}
// set one value to struct column field.
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
fieldType := fi.fieldType
@ -1063,6 +1182,7 @@ setValue:
return value, nil
}
// query sql, read values , save to *[]ParamList.
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) {
var (
@ -1150,6 +1270,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
refs[i] = &ref
}
defer rs.Close()
var (
cnt int64
columns []string
@ -1228,6 +1350,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
return cnt, nil
}
// flag of update joined record.
func (d *dbBase) SupportUpdateJoin() bool {
return true
}
@ -1236,30 +1359,37 @@ func (d *dbBase) MaxLimit() uint64 {
return 18446744073709551615
}
// return quote.
func (d *dbBase) TableQuote() string {
return "`"
}
// replace value placeholer in parametered sql string.
func (d *dbBase) ReplaceMarks(query *string) {
// default use `?` as mark, do nothing
}
// flag of RETURNING sql.
func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
return false
}
// convert time from db.
func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
*t = t.In(tz)
}
// convert time to db.
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
*t = t.In(tz)
}
// get database types.
func (d *dbBase) DbTypes() map[string]string {
return nil
}
// gt all tables.
func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
tables := make(map[string]bool)
query := d.ins.ShowTablesQuery()
@ -1268,6 +1398,8 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
return tables, err
}
defer rows.Close()
for rows.Next() {
var table string
err := rows.Scan(&table)
@ -1282,6 +1414,7 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) {
return tables, nil
}
// get all cloumns in table.
func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
columns := make(map[string][3]string)
query := d.ins.ShowColumnsQuery(table)
@ -1290,6 +1423,8 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
return columns, err
}
defer rows.Close()
for rows.Next() {
var (
name string
@ -1306,18 +1441,22 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e
return columns, nil
}
// not implement.
func (d *dbBase) OperatorSql(operator string) string {
panic(ErrNotImplement)
}
// not implement.
func (d *dbBase) ShowTablesQuery() string {
panic(ErrNotImplement)
}
// not implement.
func (d *dbBase) ShowColumnsQuery(table string) string {
panic(ErrNotImplement)
}
// not implement.
func (d *dbBase) IndexExists(dbQuerier, string, string) bool {
panic(ErrNotImplement)
}

View File

@ -9,27 +9,32 @@ import (
"time"
)
// database driver constant int.
type DriverType int
const (
_ DriverType = iota
DR_MySQL
DR_Sqlite
DR_Oracle
DR_Postgres
_ DriverType = iota // int enum type
DR_MySQL // mysql
DR_Sqlite // sqlite
DR_Oracle // oracle
DR_Postgres // pgsql
)
// database driver string.
type driver string
// get type constant int of current driver..
func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d))
return a.Driver
}
// get name of current driver
func (d driver) Name() string {
return string(d)
}
// check driver iis implemented Driver interface or not.
var _ Driver = new(driver)
var (
@ -47,11 +52,13 @@ var (
}
)
// database alias cacher.
type _dbCache struct {
mux sync.RWMutex
cache map[string]*alias
}
// add database alias with original name.
func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock()
defer ac.mux.Unlock()
@ -62,6 +69,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) {
return
}
// get database alias if cached.
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
ac.mux.RLock()
defer ac.mux.RUnlock()
@ -69,6 +77,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) {
return
}
// get default alias.
func (ac *_dbCache) getDefault() (al *alias) {
al, _ = ac.get("default")
return
@ -123,21 +132,18 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
switch al.Driver {
case DR_MySQL:
row := al.DB.QueryRow("SELECT @@session.time_zone")
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string
row.Scan(&tz)
if tz == "SYSTEM" {
tz = ""
row = al.DB.QueryRow("SELECT @@system_time_zone")
row.Scan(&tz)
t, err := time.Parse("MST", tz)
if err == nil {
al.TZ = t.Location()
if len(tz) >= 8 {
if tz[0] != '-' {
tz = "+" + tz
}
} else {
t, err := time.Parse("-07:00", tz)
t, err := time.Parse("-07:00:00", tz)
if err == nil {
al.TZ = t.Location()
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
@ -163,6 +169,8 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
loc, err := time.LoadLocation(tz)
if err == nil {
al.TZ = loc
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
)
// mysql operators.
var mysqlOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ?",
@ -21,6 +22,7 @@ var mysqlOperators = map[string]string{
"iendswith": "LIKE ?",
}
// mysql column field types.
var mysqlTypes = map[string]string{
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY",
@ -41,29 +43,35 @@ var mysqlTypes = map[string]string{
"float64-decimal": "numeric(%d, %d)",
}
// mysql dbBaser implementation.
type dbBaseMysql struct {
dbBase
}
var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
func (d *dbBaseMysql) OperatorSql(operator string) string {
return mysqlOperators[operator]
}
// get mysql table field types.
func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes
}
// show table sql for mysql.
func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
// show columns sql of table for mysql.
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
// execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
@ -72,6 +80,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
return cnt > 0
}
// create new mysql dbBaser.
func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql)
b.ins = b

View File

@ -1,11 +1,13 @@
package orm
// oracle dbBaser
type dbBaseOracle struct {
dbBase
}
var _ dbBaser = new(dbBaseOracle)
// create oracle dbBaser.
func newdbBaseOracle() dbBaser {
b := new(dbBaseOracle)
b.ins = b

View File

@ -5,6 +5,7 @@ import (
"strconv"
)
// postgresql operators.
var postgresOperators = map[string]string{
"exact": "= ?",
"iexact": "= UPPER(?)",
@ -20,6 +21,7 @@ var postgresOperators = map[string]string{
"iendswith": "LIKE UPPER(?)",
}
// postgresql column field types.
var postgresTypes = map[string]string{
"auto": "serial NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY",
@ -40,16 +42,19 @@ var postgresTypes = map[string]string{
"float64-decimal": "numeric(%d, %d)",
}
// postgresql dbBaser.
type dbBasePostgres struct {
dbBase
}
var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
func (d *dbBasePostgres) OperatorSql(operator string) string {
return postgresOperators[operator]
}
// generate functioned sql string, such as contains(text).
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
switch operator {
case "contains", "startswith", "endswith":
@ -59,6 +64,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string,
}
}
// postgresql unsupports updating joined record.
func (d *dbBasePostgres) SupportUpdateJoin() bool {
return false
}
@ -67,10 +73,13 @@ func (d *dbBasePostgres) MaxLimit() uint64 {
return 0
}
// postgresql quote is ".
func (d *dbBasePostgres) TableQuote() string {
return `"`
}
// postgresql value placeholder is $n.
// replace default ? to $n.
func (d *dbBasePostgres) ReplaceMarks(query *string) {
q := *query
num := 0
@ -97,6 +106,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
*query = string(data)
}
// make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
if mi.fields.pk.auto {
if query != nil {
@ -107,18 +117,22 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
return
}
// show table sql for postgresql.
func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
}
// show table columns sql for postgresql.
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
}
// get column types of postgresql.
func (d *dbBasePostgres) DbTypes() map[string]string {
return postgresTypes
}
// check index exist in postgresql.
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query)
@ -127,6 +141,7 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo
return cnt > 0
}
// create new postgresql dbBaser.
func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres)
b.ins = b

View File

@ -5,6 +5,7 @@ import (
"fmt"
)
// sqlite operators.
var sqliteOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'",
@ -20,6 +21,7 @@ var sqliteOperators = map[string]string{
"iendswith": "LIKE ? ESCAPE '\\'",
}
// sqlite column types.
var sqliteTypes = map[string]string{
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
"pk": "NOT NULL PRIMARY KEY",
@ -40,38 +42,47 @@ var sqliteTypes = map[string]string{
"float64-decimal": "decimal",
}
// sqlite dbBaser.
type dbBaseSqlite struct {
dbBase
}
var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator.
func (d *dbBaseSqlite) OperatorSql(operator string) string {
return sqliteOperators[operator]
}
// generate functioned sql for sqlite.
// only support DATE(text).
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
}
}
// unable updating joined record in sqlite.
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false
}
// max int in sqlite.
func (d *dbBaseSqlite) MaxLimit() uint64 {
return 9223372036854775807
}
// get column types in sqlite.
func (d *dbBaseSqlite) DbTypes() map[string]string {
return sqliteTypes
}
// get show tables sql in sqlite.
func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'"
}
// get columns in sqlite.
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
@ -92,10 +103,12 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str
return columns, nil
}
// get show columns sql in sqlite.
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table)
}
// check index exist in sqlite.
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query)
@ -113,6 +126,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
return false
}
// create new sqlite dbBaser.
func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite)
b.ins = b

View File

@ -6,6 +6,7 @@ import (
"time"
)
// table info struct.
type dbTable struct {
id int
index string
@ -18,13 +19,17 @@ type dbTable struct {
jtl *dbTable
}
// tables collection struct, contains some tables.
type dbTables struct {
tablesM map[string]*dbTable
tables []*dbTable
mi *modelInfo
base dbBaser
skipEnd bool
}
// set table info to collection.
// if not exist, create new.
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok {
@ -41,6 +46,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
return t.tablesM[name]
}
// add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; ok == false {
@ -53,11 +59,14 @@ func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
return t.tablesM[name], false
}
// get table info in collection.
func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name]
return j, ok
}
// get related fields info in recursive depth loop.
// loop once, depth decreases one.
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany {
return related
@ -78,6 +87,7 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []
return related
}
// parse related fields.
func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels)
@ -111,7 +121,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
names = append(names, fi.name)
mmi = fi.relModelInfo
if fi.null {
if fi.null || t.skipEnd {
inner = false
}
@ -139,6 +149,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
}
}
// generate join string.
func (t *dbTables) getJoinSql() (join string) {
Q := t.base.TableQuote()
@ -185,9 +196,12 @@ func (t *dbTables) getJoinSql() (join string) {
return
}
// parse orm model struct field tag expression.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
var (
jtl *dbTable
fi *fieldInfo
fiN *fieldInfo
mmi = mi
)
@ -196,9 +210,22 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
inner := true
loopFor:
for i, ex := range exprs {
fi, ok := mmi.fields.GetByAny(ex)
var ok, okN bool
if fiN != nil {
fi = fiN
ok = true
fiN = nil
}
if i == 0 {
fi, ok = mmi.fields.GetByAny(ex)
}
_ = okN
if ok {
@ -216,44 +243,61 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
mmi = fi.reverseFieldInfo.mi
}
if i < num {
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
}
if isRel && (fi.mi.isThrough == false || num != i) {
if fi.null {
if fi.null || t.skipEnd {
inner = false
}
jt, _ := t.add(names, mmi, fi, inner)
jt.jtl = jtl
jtl = jt
}
if num == i {
if i == 0 || jtl == nil {
index = "T0"
} else {
index = jtl.index
}
info = fi
if jtl == nil {
name = fi.name
} else {
name = jtl.name + ExprSep + fi.name
}
switch {
case fi.rel:
case fi.reverse:
switch fi.reverseFieldInfo.fieldType {
case RelOneToOne, RelForeignKey:
index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk
name = info.name
if t.skipEnd && okN || !t.skipEnd {
if t.skipEnd && okN && fiN.pk {
goto loopEnd
}
jt, _ := t.add(names, mmi, fi, inner)
jt.jtl = jtl
jtl = jt
}
}
if num != i {
continue
}
loopEnd:
if i == 0 || jtl == nil {
index = "T0"
} else {
index = jtl.index
}
info = fi
if jtl == nil {
name = fi.name
} else {
name = jtl.name + ExprSep + fi.name
}
switch {
case fi.rel:
case fi.reverse:
switch fi.reverseFieldInfo.fieldType {
case RelOneToOne, RelForeignKey:
index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk
name = info.name
}
}
break loopFor
} else {
index = ""
name = ""
@ -267,6 +311,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
return
}
// generate condition sql.
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
@ -331,6 +376,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
return
}
// generate order sql.
func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
if len(orders) == 0 {
return
@ -359,6 +405,7 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
return
}
// generate limit sql.
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 {
limit = int64(DefaultRowsLimit)
@ -381,6 +428,7 @@ func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits
return
}
// crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable)

View File

@ -6,6 +6,7 @@ import (
"time"
)
// get table alias.
func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok {
return al
@ -15,6 +16,7 @@ func getDbAlias(name string) *alias {
return nil
}
// get pk column info.
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
@ -37,6 +39,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
return
}
// get fields description as flatted string.
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor:

View File

@ -41,6 +41,7 @@ var (
}
)
// model info collection
type _modelCache struct {
sync.RWMutex
orders []string
@ -49,6 +50,7 @@ type _modelCache struct {
done bool
}
// get all model info
func (mc *_modelCache) all() map[string]*modelInfo {
m := make(map[string]*modelInfo, len(mc.cache))
for k, v := range mc.cache {
@ -57,6 +59,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
return m
}
// get orderd model info
func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders))
for _, table := range mc.orders {
@ -65,16 +68,19 @@ func (mc *_modelCache) allOrdered() []*modelInfo {
return m
}
// get model info by table name
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
mi, ok = mc.cache[table]
return
}
// get model info by field name
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFN[name]
return
}
// set model info to collection
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table]
mc.cache[table] = mi
@ -85,6 +91,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
return mii
}
// clean all model info.
func (mc *_modelCache) clean() {
mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo)

View File

@ -8,6 +8,8 @@ import (
"strings"
)
// register models.
// prefix means table name prefix.
func registerModel(model interface{}, prefix string) {
val := reflect.ValueOf(model)
ind := reflect.Indirect(val)
@ -67,6 +69,7 @@ func registerModel(model interface{}, prefix string) {
modelCache.set(table, info)
}
// boostrap models
func bootStrap() {
if modelCache.done {
return
@ -281,6 +284,7 @@ end:
}
}
// register models
func RegisterModel(models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
@ -302,6 +306,8 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
}
}
// bootrap models.
// make all model parsed and can not add more models
func BootStrap() {
if modelCache.done {
return

View File

@ -9,6 +9,7 @@ import (
var errSkipField = errors.New("skip field")
// field info collection
type fields struct {
pk *fieldInfo
columns map[string]*fieldInfo
@ -23,6 +24,7 @@ type fields struct {
dbcols []string
}
// add field info
func (f *fields) Add(fi *fieldInfo) (added bool) {
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
f.columns[fi.column] = fi
@ -49,14 +51,17 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
return true
}
// get field info by name
func (f *fields) GetByName(name string) *fieldInfo {
return f.fields[name]
}
// get field info by column name
func (f *fields) GetByColumn(column string) *fieldInfo {
return f.columns[column]
}
// get field info by string, name is prior
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
if fi, ok := f.fields[name]; ok {
return fi, ok
@ -70,6 +75,7 @@ func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
return nil, false
}
// create new field info collection
func newFields() *fields {
f := new(fields)
f.fields = make(map[string]*fieldInfo)
@ -79,6 +85,7 @@ func newFields() *fields {
return f
}
// single field info
type fieldInfo struct {
mi *modelInfo
fieldIndex int
@ -115,6 +122,7 @@ type fieldInfo struct {
onDelete string
}
// new field info
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
var (
tag string

View File

@ -7,6 +7,7 @@ import (
"reflect"
)
// single model info
type modelInfo struct {
pkg string
name string
@ -20,6 +21,7 @@ type modelInfo struct {
isThrough bool
}
// new model info
func newModelInfo(val reflect.Value) (info *modelInfo) {
var (
err error
@ -79,6 +81,8 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
return
}
// combine related model info to new model info.
// prepare for relation models query.
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info = new(modelInfo)
info.fields = newFields()

View File

@ -7,10 +7,12 @@ import (
"time"
)
// get reflect.Type name with package path.
func getFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name()
}
// get table name. method, or field name. auto snaked.
func getTableName(val reflect.Value) string {
ind := reflect.Indirect(val)
fun := val.MethodByName("TableName")
@ -26,6 +28,7 @@ func getTableName(val reflect.Value) string {
return snakeString(ind.Type().Name())
}
// get table engine, mysiam or innodb.
func getTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine")
if fun.IsValid() {
@ -40,6 +43,7 @@ func getTableEngine(val reflect.Value) string {
return ""
}
// get table index from method.
func getTableIndex(val reflect.Value) [][]string {
fun := val.MethodByName("TableIndex")
if fun.IsValid() {
@ -56,6 +60,7 @@ func getTableIndex(val reflect.Value) [][]string {
return nil
}
// get table unique from method
func getTableUnique(val reflect.Value) [][]string {
fun := val.MethodByName("TableUnique")
if fun.IsValid() {
@ -72,6 +77,7 @@ func getTableUnique(val reflect.Value) [][]string {
return nil
}
// get snaked column name
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
col = strings.ToLower(col)
column := col
@ -89,6 +95,7 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
return column
}
// return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val)
switch elm.Kind() {
@ -128,6 +135,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
return
}
// parse struct tag string
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
attr := make(map[string]bool)
tag := make(map[string]string)

View File

@ -25,6 +25,7 @@ var (
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> no row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrArgs = errors.New("<Ormer> args error may be empty")
ErrNotImplement = errors.New("have not implement")
)
@ -39,11 +40,12 @@ type orm struct {
var _ Ormer = new(orm)
func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
// get model info and model reflect value
func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
val := reflect.ValueOf(md)
ind = reflect.Indirect(val)
typ := ind.Type()
if val.Kind() != reflect.Ptr {
if needPtr && val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
name := getFullName(typ)
@ -53,6 +55,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
}
// get field info from model info by given field name
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
fi, ok := mi.fields.GetByAny(name)
if !ok {
@ -61,8 +64,9 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
return fi
}
// read data to model
func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return err
@ -70,26 +74,69 @@ func (o *orm) Read(md interface{}, cols ...string) error {
return nil
}
// insert model data to database
func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return id, err
}
if id > 0 {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
}
}
}
o.setPk(mi, ind, id)
return id, nil
}
// set auto pk field
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
}
}
}
// insert some models to database
func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
var cnt int64
sind := reflect.Indirect(reflect.ValueOf(mds))
switch sind.Kind() {
case reflect.Array, reflect.Slice:
if sind.Len() == 0 {
return cnt, ErrArgs
}
default:
return cnt, ErrArgs
}
if bulk <= 1 {
for i := 0; i < sind.Len(); i++ {
ind := sind.Index(i)
mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return cnt, err
}
o.setPk(mi, ind, id)
cnt += 1
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
}
return cnt, nil
}
// update model to database.
// cols set the columns those want to update.
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return num, err
@ -97,26 +144,22 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
return num, nil
}
// delete model in database
func (o *orm) Delete(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
if err != nil {
return num, err
}
if num > 0 {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
o.setPk(mi, ind, 0)
}
return num, nil
}
// create a models to models queryer
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name)
switch {
@ -129,6 +172,14 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
return newQueryM2M(md, o, mi, fi, ind)
}
// load related models to md model.
// args are limit, offset int and order string.
//
// example:
// orm.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//
// make sure the relation is defined in model struct tags.
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
_, fi, ind, qseter := o.queryRelated(md, name)
@ -190,14 +241,21 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int
return nums, err
}
// return a QuerySeter for related models to md model.
// it can do all, update, delete in QuerySeter.
// example:
// qs := orm.QueryRelated(post,"Tag")
// qs.All(&[]*Tag{})
//
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
// is this api needed ?
_, _, _, qs := o.queryRelated(md, name)
return qs
}
// get QuerySeter for related models to md model
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name)
_, _, exist := getExistPk(mi, ind)
@ -227,6 +285,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
return mi, fi, ind, qs
}
// get reverse relation QuerySeter
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType {
case RelReverseOne, RelReverseMany:
@ -247,6 +306,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
return q
}
// get relation QuerySeter
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelManyToMany:
@ -266,6 +326,9 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
return q
}
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
name := ""
if table, ok := ptrStructOrTableName.(string); ok {
@ -285,6 +348,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
return
}
// switch to another registered database driver by given name.
func (o *orm) Using(name string) error {
if o.isTx {
panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
@ -302,6 +366,7 @@ func (o *orm) Using(name string) error {
return nil
}
// begin transaction
func (o *orm) Begin() error {
if o.isTx {
return ErrTxHasBegan
@ -320,6 +385,7 @@ func (o *orm) Begin() error {
return nil
}
// commit transaction
func (o *orm) Commit() error {
if o.isTx == false {
return ErrTxDone
@ -334,6 +400,7 @@ func (o *orm) Commit() error {
return err
}
// rollback transaction
func (o *orm) Rollback() error {
if o.isTx == false {
return ErrTxDone
@ -348,14 +415,17 @@ func (o *orm) Rollback() error {
return err
}
// return a raw query seter for raw sql string.
func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args)
}
// return current using database Driver
func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
// create new orm
func NewOrm() Ormer {
BootStrap() // execute only once

View File

@ -18,15 +18,19 @@ type condValue struct {
isCond bool
}
// condition struct.
// work for WHERE conditions.
type Condition struct {
params []condValue
}
// return new condition struct
func NewCondition() *Condition {
c := &Condition{}
return c
}
// add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.And> args cannot empty"))
@ -35,6 +39,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c
}
// add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
@ -43,6 +48,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c
}
// combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
@ -54,6 +60,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c
}
// add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.Or> args cannot empty"))
@ -62,6 +69,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c
}
// add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
@ -70,6 +78,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c
}
// combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
@ -81,10 +90,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c
}
// check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool {
return len(c.params) == 0
}
// clone a condition
func (c Condition) clone() *Condition {
return &c
}

View File

@ -13,6 +13,7 @@ type Log struct {
*log.Logger
}
// set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log {
d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9)
@ -40,6 +41,8 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
DebugLog.Println(con)
}
// statement query logger struct.
// if dev mode, use stmtQueryLog, or use stmtQuerier.
type stmtQueryLog struct {
alias *alias
query string
@ -84,6 +87,8 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
return d
}
// database query logger struct.
// if dev mode, use dbQueryLog, or use dbQuerier.
type dbQueryLog struct {
alias *alias
db dbQuerier

View File

@ -5,6 +5,7 @@ import (
"reflect"
)
// an insert queryer struct
type insertSet struct {
mi *modelInfo
orm *orm
@ -14,6 +15,7 @@ type insertSet struct {
var _ Inserter = new(insertSet)
// insert model ignore it's registered or not.
func (o *insertSet) Insert(md interface{}) (int64, error) {
if o.closed {
return 0, ErrStmtClosed
@ -44,6 +46,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
return id, nil
}
// close insert queryer statement
func (o *insertSet) Close() error {
if o.closed {
return ErrStmtClosed
@ -52,6 +55,7 @@ func (o *insertSet) Close() error {
return o.stmt.Close()
}
// create new insert queryer.
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
bi := new(insertSet)
bi.orm = orm

View File

@ -4,6 +4,7 @@ import (
"reflect"
)
// model to model struct
type queryM2M struct {
md interface{}
mi *modelInfo
@ -12,6 +13,13 @@ type queryM2M struct {
ind reflect.Value
}
// add models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}
//
// make sure the relation is defined in post model struct tag.
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
fi := o.fi
mi := fi.relThroughModelInfo
@ -44,7 +52,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
names := []string{mfi.column, rfi.column}
var nums int64
values := make([]interface{}, 0, len(models)*2)
for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md))
@ -59,18 +68,14 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
}
}
values := []interface{}{v1, v2}
_, err := dbase.InsertValue(orm.db, mi, names, values)
if err != nil {
return nums, err
}
values = append(values, v1, v2)
nums += 1
}
return nums, nil
return dbase.InsertValue(orm.db, mi, true, names, values)
}
// remove models following the origin model relationship
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
@ -82,17 +87,20 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
return nums, nil
}
// check model is existed in relationship of origin model
func (o *queryM2M) Exist(md interface{}) bool {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
}
// clean all models in related of origin model
func (o *queryM2M) Clear() (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
}
// count all related models of origin model
func (o *queryM2M) Count() (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
@ -100,6 +108,7 @@ func (o *queryM2M) Count() (int64, error) {
var _ QueryM2Mer = new(queryM2M)
// create new M2M queryer.
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
qm2m := new(queryM2M)
qm2m.md = md

View File

@ -18,6 +18,10 @@ const (
Col_Except
)
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
// Params{
// "Nums": ColValue(Col_Add, 10),
// }
func ColValue(opt operator, value interface{}) interface{} {
switch opt {
case Col_Add, Col_Minus, Col_Multiply, Col_Except:
@ -34,6 +38,7 @@ func ColValue(opt operator, value interface{}) interface{} {
return val
}
// real query struct
type querySet struct {
mi *modelInfo
cond *Condition
@ -47,6 +52,7 @@ type querySet struct {
var _ QuerySeter = new(querySet)
// add condition expression to QuerySeter.
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
@ -55,6 +61,7 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
return &o
}
// add NOT condition to querySeter.
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
@ -63,10 +70,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
return &o
}
// set offset number
func (o *querySet) setOffset(num interface{}) {
o.offset = ToInt64(num)
}
// add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
o.limit = ToInt64(limit)
if len(args) > 0 {
@ -75,16 +85,21 @@ func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
return &o
}
// add OFFSET value
func (o querySet) Offset(offset interface{}) QuerySeter {
o.setOffset(offset)
return &o
}
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter {
o.orders = exprs
return &o
}
// set relation model to query together.
// it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
var related []string
if len(params) == 0 {
@ -105,36 +120,50 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
return &o
}
// set condition to QuerySeter.
func (o querySet) SetCond(cond *Condition) QuerySeter {
o.cond = cond
return &o
}
// return QuerySeter execution result number
func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// check result empty or not after QuerySeter executed
func (o *querySet) Exist() bool {
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return cnt > 0
}
// execute update with parameters
func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
}
// execute delete
func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// return a insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi)
}
// query all data and map to containers.
// cols means the columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
}
// query one row data and map to containers.
// cols means the columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil {
@ -149,18 +178,26 @@ func (o *querySet) One(container interface{}, cols ...string) error {
return nil
}
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query all data and map to [][]interface
// it converts data to [][column_index]value
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query all data and map to []interface.
// it's designed for one row record set, auto change to []value, not [][column]value.
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
}
// create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet)
o.mi = mi

View File

@ -4,10 +4,10 @@ import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
)
// raw sql string prepared statement
type rawPrepare struct {
rs *rawSet
stmt stmtQuerier
@ -45,6 +45,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) {
return o, nil
}
// raw query seter
type rawSet struct {
query string
args []interface{}
@ -53,11 +54,13 @@ type rawSet struct {
var _ RawSeter = new(rawSet)
// set args for every query
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args
return &o
}
// execute raw sql and return sql.Result
func (o *rawSet) Exec() (sql.Result, error) {
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
@ -66,6 +69,7 @@ func (o *rawSet) Exec() (sql.Result, error) {
return o.orm.db.Exec(query, args...)
}
// set field value to row container
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
switch ind.Kind() {
case reflect.Bool:
@ -164,65 +168,12 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
}
}
func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) {
sIdxes := *sIdxesPtr
refs := *refsPtr
if typ.Kind() == reflect.Struct {
if typ.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
} else {
idxs := []int{}
outFor:
for idx := 0; idx < typ.NumField(); idx++ {
ctyp := typ.Field(idx)
tag := ctyp.Tag.Get(defaultStructTagName)
for _, v := range strings.Split(tag, defaultStructTagDelim) {
if v == "-" {
continue outFor
}
}
tp := ctyp.Type
if tp.Kind() == reflect.Ptr {
tp = tp.Elem()
}
if tp.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
} else if tp.Kind() != reflect.Struct {
var ref interface{}
refs = append(refs, &ref)
} else {
// skip other type
continue
}
idxs = append(idxs, idx)
}
sIdxes = append(sIdxes, idxs)
}
} else {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
}
*sIdxesPtr = sIdxes
*refsPtr = refs
}
func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
// set field value in loop for slice container
func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
nInds := *nIndsPtr
cur := 0
for i, idxs := range sIdxes {
for i := 0; i < len(sInds); i++ {
sInd := sInds[i]
eTyp := eTyps[i]
@ -258,32 +209,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
o.setFieldValue(ind, value)
}
cur++
} else {
hasValue := false
for _, idx := range idxs {
tind := ind.Field(idx)
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if value != nil {
hasValue = true
}
if tind.Kind() == reflect.Ptr {
if value == nil {
tindV := reflect.New(tind.Type()).Elem()
tind.Set(tindV)
} else {
tindV := reflect.New(tind.Type().Elem())
o.setFieldValue(tindV.Elem(), value)
tind.Set(tindV)
}
} else {
o.setFieldValue(tind, value)
}
cur++
}
if hasValue == false && isPtr {
val = reflect.New(val.Type()).Elem()
}
}
} else {
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if isPtr && value == nil {
@ -312,16 +239,14 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
}
}
// query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error {
if len(containers) == 0 {
panic(fmt.Errorf("<RawSeter.QueryRow> need at least one arg"))
}
refs := make([]interface{}, 0, len(containers))
sIdxes := make([][]int, 0)
sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
@ -335,44 +260,123 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
sInds = append(sInds, ind)
eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes)
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
row := o.orm.db.QueryRow(query, args...)
if err := row.Scan(refs...); err == sql.ErrNoRows {
return ErrNoRows
} else if err != nil {
rows, err := o.orm.db.Query(query, args...)
if err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
}
return err
}
nInds := make([]reflect.Value, len(sInds))
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true)
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
defer rows.Close()
if rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return err
}
ind := sInds[0]
if ind.Kind() == reflect.Ptr {
if ind.IsNil() || !ind.IsValid() {
ind.Set(reflect.New(eTyps[0].Elem()))
}
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
} else {
if err := rows.Scan(refs...); err != nil {
return err
}
nInds := make([]reflect.Value, len(sInds))
o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
}
}
} else {
return ErrNoRows
}
return nil
}
// query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs := make([]interface{}, 0)
sIdxes := make([][]int, 0)
refs := make([]interface{}, 0, len(containers))
sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
sInd := reflect.Indirect(val)
@ -389,7 +393,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
sInds = append(sInds, sInd)
eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes)
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
}
query := o.query
@ -401,23 +418,100 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
return 0, err
}
nInds := make([]reflect.Value, len(sInds))
defer rows.Close()
var cnt int64
for rows.Next() {
if err := rows.Scan(refs...); err != nil {
return 0, err
}
nInds := make([]reflect.Value, len(sInds))
sInd := sInds[0]
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0)
for rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return 0, err
}
if cnt == 0 && !sInd.IsNil() {
sInd.Set(reflect.New(sInd.Type()).Elem())
}
var ind reflect.Value
if eTyps[0].Kind() == reflect.Ptr {
ind = reflect.New(eTyps[0].Elem())
} else {
ind = reflect.New(eTyps[0])
}
if ind.Kind() == reflect.Ptr {
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
if eTyps[0].Kind() == reflect.Ptr {
ind = ind.Addr()
}
sInd = reflect.Append(sInd, ind)
} else {
if err := rows.Scan(refs...); err != nil {
return 0, err
}
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
}
cnt++
}
if cnt > 0 {
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
if structMode {
sInds[0].Set(sInd)
} else {
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
}
}
}
@ -455,6 +549,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
rs = r
}
defer rs.Close()
var (
refs []interface{}
cnt int64
@ -527,18 +623,22 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
return cnt, nil
}
// query data to []map[string]interface
func (o *rawSet) Values(container *[]Params) (int64, error) {
return o.readValues(container)
}
// query data to [][]interface
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
return o.readValues(container)
}
// query data to []interface
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
return o.readValues(container)
}
// return prepared raw statement for used in times.
func (o *rawSet) Prepare() (RawPreparer, error) {
return newRawPreparer(o)
}

View File

@ -1322,58 +1322,6 @@ func TestRawQueryRow(t *testing.T) {
}
}
type Tmp struct {
Skip0 string
Id int
Char *string
Skip1 int `orm:"-"`
Date time.Time
DateTime time.Time
}
Boolean = false
Text = ""
Int64 = 0
Uint = 0
tmp := new(Tmp)
cols = []string{
"int", "char", "date", "datetime", "boolean", "text", "int64", "uint",
}
query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
values = []interface{}{
tmp, &Boolean, &Text, &Int64, &Uint,
}
err = dORM.Raw(query, 1).QueryRow(values...)
throwFailNow(t, err)
for _, col := range cols {
switch col {
case "id":
throwFail(t, AssertIs(tmp.Id, data_values[col]))
case "char":
c := tmp.Char
throwFail(t, AssertIs(*c, data_values[col]))
case "date":
v := tmp.Date.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_Date))
case "datetime":
v := tmp.DateTime.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_DateTime))
case "boolean":
throwFail(t, AssertIs(Boolean, data_values[col]))
case "text":
throwFail(t, AssertIs(Text, data_values[col]))
case "int64":
throwFail(t, AssertIs(Int64, data_values[col]))
case "uint":
throwFail(t, AssertIs(Uint, data_values[col]))
}
}
var (
uid int
status *int
@ -1394,22 +1342,13 @@ func TestRawQueryRow(t *testing.T) {
func TestQueryRows(t *testing.T) {
Q := dDbBaser.TableQuote()
cols := []string{
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
}
var datas []*Data
var dids []int
sep := fmt.Sprintf("%s, %s", Q, Q)
query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q)
num, err := dORM.Raw(query).QueryRows(&datas, &dids)
query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
num, err := dORM.Raw(query).QueryRows(&datas)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(datas), 1))
throwFailNow(t, AssertIs(len(dids), 1))
throwFailNow(t, AssertIs(dids[0], 1))
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
@ -1427,90 +1366,43 @@ func TestQueryRows(t *testing.T) {
throwFail(t, AssertIs(vu == value, true), value, vu)
}
type Tmp struct {
Id int
Name string
Skiped0 string `orm:"-"`
Pid *int
Skiped1 Data
Skiped2 *Data
var datas2 []Data
query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
num, err = dORM.Raw(query).QueryRows(&datas2)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(datas2), 1))
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
for name, value := range Data_Values {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
var (
ids []int
userNames []string
profileIds1 []int
profileIds2 []*int
createds []time.Time
updateds []time.Time
tmps1 []*Tmp
tmps2 []Tmp
)
cols = []string{
"id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated",
}
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q)
num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds)
var ids []int
var usernames []string
query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q)
num, err = dORM.Raw(query).QueryRows(&ids, &usernames)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
var users []User
dORM.QueryTable("user").OrderBy("Id").All(&users)
for i := 0; i < 3; i++ {
id := ids[i]
name := userNames[i]
pid1 := profileIds1[i]
pid2 := profileIds2[i]
created := createds[i]
updated := updateds[i]
user := users[i]
throwFailNow(t, AssertIs(id, user.Id))
throwFailNow(t, AssertIs(name, user.UserName))
if user.Profile != nil {
throwFailNow(t, AssertIs(pid1, user.Profile.Id))
throwFailNow(t, AssertIs(*pid2, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(pid1, 0))
throwFailNow(t, AssertIs(pid2, nil))
}
throwFailNow(t, AssertIs(created, user.Created, test_Date))
throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime))
tmp := tmps1[i]
tmp1 := *tmp
throwFailNow(t, AssertIs(tmp1.Id, user.Id))
throwFailNow(t, AssertIs(tmp1.Name, user.UserName))
if user.Profile != nil {
pid := tmp1.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp1.Pid, nil))
}
tmp2 := tmps2[i]
throwFailNow(t, AssertIs(tmp2.Id, user.Id))
throwFailNow(t, AssertIs(tmp2.Name, user.UserName))
if user.Profile != nil {
pid := tmp2.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp2.Pid, nil))
}
}
type Sec struct {
Id int
Name string
}
var tmp []*Sec
query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q)
num, err = dORM.Raw(query).QueryRows(&tmp)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(tmp[0], nil))
throwFailNow(t, AssertIs(len(ids), 3))
throwFailNow(t, AssertIs(ids[0], 2))
throwFailNow(t, AssertIs(usernames[0], "slene"))
throwFailNow(t, AssertIs(ids[1], 3))
throwFailNow(t, AssertIs(usernames[1], "astaxie"))
throwFailNow(t, AssertIs(ids[2], 4))
throwFailNow(t, AssertIs(usernames[2], "nobody"))
}
func TestRawValues(t *testing.T) {
@ -1669,6 +1561,32 @@ func TestDelete(t *testing.T) {
num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 6))
qs = dORM.QueryTable("post")
num, err = qs.Filter("Id", 3).Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 4))
fmt.Println("...")
qs = dORM.QueryTable("comment")
num, err = qs.Filter("Post__User", 3).Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
func TestTransaction(t *testing.T) {

View File

@ -6,11 +6,13 @@ import (
"time"
)
// database driver
type Driver interface {
Name() string
Type() DriverType
}
// field info
type Fielder interface {
String() string
FieldType() int
@ -18,9 +20,11 @@ type Fielder interface {
RawValue() interface{}
}
// orm struct
type Ormer interface {
Read(interface{}, ...string) error
Insert(interface{}) (int64, error)
InsertMulti(int, interface{}) (int64, error)
Update(interface{}, ...string) (int64, error)
Delete(interface{}) (int64, error)
LoadRelated(interface{}, string, ...interface{}) (int64, error)
@ -34,11 +38,13 @@ type Ormer interface {
Driver() Driver
}
// insert prepared statement
type Inserter interface {
Insert(interface{}) (int64, error)
Close() error
}
// query seter
type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter
@ -59,6 +65,7 @@ type QuerySeter interface {
ValuesFlat(*ParamsList, string) (int64, error)
}
// model to model query struct
type QueryM2Mer interface {
Add(...interface{}) (int64, error)
Remove(...interface{}) (int64, error)
@ -67,11 +74,13 @@ type QueryM2Mer interface {
Count() (int64, error)
}
// raw query statement
type RawPreparer interface {
Exec(...interface{}) (sql.Result, error)
Close() error
}
// raw query seter
type RawSeter interface {
Exec() (sql.Result, error)
QueryRow(...interface{}) error
@ -83,6 +92,7 @@ type RawSeter interface {
Prepare() (RawPreparer, error)
}
// statement querier
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
@ -90,6 +100,7 @@ type stmtQuerier interface {
QueryRow(args ...interface{}) *sql.Row
}
// db querier
type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error)
@ -97,19 +108,23 @@ type dbQuerier interface {
QueryRow(query string, args ...interface{}) *sql.Row
}
// transaction beginner
type txer interface {
Begin() (*sql.Tx, error)
}
// transaction ending
type txEnder interface {
Commit() error
Rollback() error
}
// base database struct
type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error)
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)

View File

@ -10,6 +10,7 @@ import (
type StrTo string
// set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
@ -18,77 +19,93 @@ func (f *StrTo) Set(v string) {
}
}
// clean string
func (f *StrTo) Clear() {
*f = StrTo(0x1E)
}
// check string exist
func (f StrTo) Exist() bool {
return string(f) != string(0x1E)
}
// string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
// string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
// string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
// string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
// string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
// string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
// string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
// string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err
}
// string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
// string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
// string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
// string to uint31
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
// string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err
}
// string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
@ -96,6 +113,7 @@ func (f StrTo) String() string {
return ""
}
// interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
@ -134,6 +152,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s
}
// interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {
@ -147,6 +166,7 @@ func ToInt64(value interface{}) (d int64) {
return
}
// snake string, XxYy to xx_yy
func snakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
@ -164,6 +184,7 @@ func snakeString(s string) string {
return strings.ToLower(string(data[:len(data)]))
}
// camel string, xx_yy to XxYy
func camelString(s string) string {
data := make([]byte, 0, len(s))
j := false
@ -190,6 +211,7 @@ func camelString(s string) string {
type argString []string
// get string by index from string slice
func (a argString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) {
r = a[i]
@ -201,6 +223,7 @@ func (a argString) Get(i int, args ...string) (r string) {
type argInt []int
// get int by index from int slice
func (a argInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) {
r = a[i]
@ -213,6 +236,7 @@ func (a argInt) Get(i int, args ...int) (r int) {
type argAny []interface{}
// get interface by index from interface slice
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) {
r = a[i]
@ -223,15 +247,18 @@ func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
return
}
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
// format time string
func timeFormat(t time.Time, format string) string {
return t.Format(format)
}
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() {
case reflect.Ptr:

View File

@ -1,7 +1,10 @@
package beego
import (
"bufio"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
@ -30,6 +33,14 @@ const (
var (
// supported http methods.
HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"}
// these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
"RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJson", "ServeJsonp",
"ServeXml", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
"GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
"DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
"GetControllerAndAction"}
)
type controllerInfo struct {
@ -77,7 +88,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
params := make(map[int]string)
for i, part := range parts {
if strings.HasPrefix(part, ":") {
expr := "(.+)"
expr := "(.*)"
//a user may choose to override the defult expression
// similar to expressjs: /user/:id([0-9]+)
if index := strings.Index(part, "("); index != -1 {
@ -100,7 +111,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
j++
}
if strings.HasPrefix(part, "*") {
expr := "(.+)"
expr := "(.*)"
if part == "*.*" {
params[j] = ":path"
parts[i] = "([^.]+).([^.]+)"
@ -218,8 +229,8 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
// Add auto router to ControllerRegistor.
// example beego.AddAuto(&MainContorlller{}),
// MainController has method List and Page.
// visit the url /main/list to exec List function
// /main/page to exec Page function.
// visit the url /main/list to execute List function
// /main/page to execute Page function.
func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
p.enableAuto = true
reflectVal := reflect.ValueOf(c)
@ -232,14 +243,42 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
p.autoRouter[firstParam] = make(map[string]reflect.Type)
}
for i := 0; i < rt.NumMethod(); i++ {
p.autoRouter[firstParam][rt.Method(i).Name] = ct
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
p.autoRouter[firstParam][rt.Method(i).Name] = ct
}
}
}
// Add auto router to ControllerRegistor with prefix.
// example beego.AddAutoPrefix("/admin",&MainContorlller{}),
// MainController has method List and Page.
// visit the url /admin/main/list to execute List function
// /admin/main/page to execute Page function.
func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) {
p.enableAuto = true
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
firstParam := strings.Trim(prefix, "/") + "/" + strings.ToLower(strings.TrimSuffix(ct.Name(), "Controller"))
if _, ok := p.autoRouter[firstParam]; ok {
return
} else {
p.autoRouter[firstParam] = make(map[string]reflect.Type)
}
for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
p.autoRouter[firstParam][rt.Method(i).Name] = ct
}
}
}
// [Deprecated] use InsertFilter.
// Add FilterFunc with pattern for action.
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) {
mr := buildFilter(pattern, filter)
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) error {
mr, err := buildFilter(pattern, filter)
if err != nil {
return err
}
switch action {
case "BeforeRouter":
p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr)
@ -253,13 +292,18 @@ func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc
p.filters[FinishRouter] = append(p.filters[FinishRouter], mr)
}
p.enableFilter = true
return nil
}
// Add a FilterFunc with pattern rule and action constant.
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) {
mr := buildFilter(pattern, filter)
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) error {
mr, err := buildFilter(pattern, filter)
if err != nil {
return err
}
p.filters[pos] = append(p.filters[pos], mr)
p.enableFilter = true
return nil
}
// UrlFor does another controller handler in this request function.
@ -485,7 +529,9 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// session init
if SessionOn {
context.Input.CruSession = GlobalSessions.SessionStart(w, r)
defer context.Input.CruSession.SessionRelease()
defer func() {
context.Input.CruSession.SessionRelease(w)
}()
}
if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) {
@ -575,12 +621,11 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
// pattern /admin url /admin 200 /admin/ 200
// pattern /admin/ url /admin 301 /admin/ 200
if requestPath[n-1] != '/' && len(route.pattern) == n+1 &&
route.pattern[n] == '/' && route.pattern[:n] == requestPath {
if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
http.Redirect(w, r, requestPath+"/", 301)
goto Admin
}
if requestPath[n-1] == '/' && n >= 2 && requestPath[:n-2] == route.pattern {
if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
runMethod = p.getRunMethod(r.Method, context, route)
if runMethod != "" {
runrouter = route.controllerType
@ -857,3 +902,13 @@ func (w *responseWriter) WriteHeader(code int) {
w.started = true
w.writer.WriteHeader(code)
}
// hijacker for http
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj, ok := w.writer.(http.Hijacker)
if !ok {
println("supported?")
return nil, nil, errors.New("webserver doesn't support hijacking")
}
return hj.Hijack()
}

View File

@ -198,3 +198,15 @@ func TestPrepare(t *testing.T) {
t.Errorf(w.Body.String() + "user define func can't run")
}
}
func TestAutoPrefix(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/test/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddAutoPrefix("/admin", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("TestAutoPrefix can't run")
}
}

View File

@ -28,21 +28,21 @@ Then in you web app init the global session manager
* Use **memory** as provider:
func init() {
globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"")
globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
go globalSessions.GC()
}
* Use **file** as provider, the last param is the path where you want file to be stored:
func init() {
globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp")
globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","./tmp"}`)
go globalSessions.GC()
}
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
func init() {
globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie")
globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","127.0.0.1:6379,100,astaxie"}`)
go globalSessions.GC()
}
@ -50,15 +50,24 @@ Then in you web app init the global session manager
func init() {
globalSessions, _ = session.NewManager(
"mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value")
"mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","username:password@protocol(address)/dbname?param=value"}`)
go globalSessions.GC()
}
* Use **Cookie** as provider:
func init() {
globalSessions, _ = session.NewManager(
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
go globalSessions.GC()
}
Finally in the handlerfunc you can use it like this
func login(w http.ResponseWriter, r *http.Request) {
sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease()
defer sess.SessionRelease(w)
username := sess.Get("username")
fmt.Println(username)
if r.Method == "GET" {
@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus
Writing a provider is easy. You only need to define two struct types
(Session and Provider), which satisfy the interface definition.
Maybe you will find the **memory** provider as good example.
Maybe you will find the **memory** provider is a good example.
type SessionStore interface {
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID
SessionRelease() // release the resource & save data to provider
Flush() error //delete all data
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data
}
type Provider interface {
SessionInit(maxlifetime int64, savePath string) error
SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) bool
SessionRegenerate(oldsid, sid string) (SessionStore, error)

145
session/sess_cookie.go Normal file
View File

@ -0,0 +1,145 @@
package session
import (
"crypto/aes"
"crypto/cipher"
"encoding/json"
"net/http"
"net/url"
"sync"
)
var cookiepder = &CookieProvider{}
type CookieSessionStore struct {
sid string
values map[interface{}]interface{} //session data
lock sync.RWMutex
}
func (st *CookieSessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
func (st *CookieSessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
} else {
return nil
}
return nil
}
func (st *CookieSessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
func (st *CookieSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
func (st *CookieSessionStore) SessionID() string {
return st.sid
}
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
str, err := encodeCookie(cookiepder.block,
cookiepder.config.SecurityKey,
cookiepder.config.SecurityName,
st.values)
if err != nil {
return
}
cookie := &http.Cookie{Name: cookiepder.config.CookieName,
Value: url.QueryEscape(str),
Path: "/",
HttpOnly: true,
Secure: cookiepder.config.Secure}
http.SetCookie(w, cookie)
return
}
type cookieConfig struct {
SecurityKey string `json:"securityKey"`
BlockKey string `json:"blockKey"`
SecurityName string `json:"securityName"`
CookieName string `json:"cookieName"`
Secure bool `json:"secure"`
Maxage int `json:"maxage"`
}
type CookieProvider struct {
maxlifetime int64
config *cookieConfig
block cipher.Block
}
func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error {
pder.config = &cookieConfig{}
err := json.Unmarshal([]byte(config), pder.config)
if err != nil {
return err
}
if pder.config.BlockKey == "" {
pder.config.BlockKey = string(generateRandomKey(16))
}
if pder.config.SecurityName == "" {
pder.config.SecurityName = string(generateRandomKey(20))
}
pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
if err != nil {
return err
}
return nil
}
func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
sid, pder.maxlifetime)
if maps == nil {
maps = make(map[interface{}]interface{})
}
rs := &CookieSessionStore{sid: sid, values: maps}
return rs, nil
}
func (pder *CookieProvider) SessionExist(sid string) bool {
return true
}
func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
return nil, nil
}
func (pder *CookieProvider) SessionDestroy(sid string) error {
return nil
}
func (pder *CookieProvider) SessionGC() {
return
}
func (pder *CookieProvider) SessionAll() int {
return 0
}
func (pder *CookieProvider) SessionUpdate(sid string) error {
return nil
}
func init() {
Register("cookie", cookiepder)
}

View File

@ -0,0 +1,38 @@
package session
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
globalSessions, err := NewManager("cookie", config)
if err != nil {
t.Fatal("init cookie session err", err)
}
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess := globalSessions.SessionStart(w, r)
err = sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
t.Fatal("get username error")
}
sess.SessionRelease(w)
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
@ -60,7 +61,7 @@ func (fs *FileSessionStore) SessionID() string {
return fs.sid
}
func (fs *FileSessionStore) SessionRelease() {
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
defer fs.f.Close()
b, err := encodeGob(fs.values)
if err != nil {

View File

@ -1,38 +0,0 @@
package session
import (
"bytes"
"encoding/gob"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}

View File

@ -2,6 +2,7 @@ package session
import (
"container/list"
"net/http"
"sync"
"time"
)
@ -9,9 +10,9 @@ import (
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
type MemSessionStore struct {
sid string //session id唯一标示
timeAccessed time.Time //最后访问时间
value map[interface{}]interface{} //session里面存储的值
sid string //session id
timeAccessed time.Time //last access time
value map[interface{}]interface{} //session store
lock sync.RWMutex
}
@ -51,8 +52,7 @@ func (st *MemSessionStore) SessionID() string {
return st.sid
}
func (st *MemSessionStore) SessionRelease() {
func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
}
type MemProvider struct {

35
session/sess_mem_test.go Normal file
View File

@ -0,0 +1,35 @@
package session
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestMem(t *testing.T) {
globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`)
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease(w)
err := sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
t.Fatal("get username error")
}
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}

View File

@ -9,6 +9,7 @@ package session
import (
"database/sql"
"net/http"
"sync"
"time"
@ -60,15 +61,15 @@ func (st *MysqlSessionStore) SessionID() string {
return st.sid
}
func (st *MysqlSessionStore) SessionRelease() {
func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close()
if len(st.values) > 0 {
b, err := encodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE session set `session_data`= ? where session_key=?", b, st.sid)
b, err := encodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid)
}
type MysqlProvider struct {
@ -96,7 +97,8 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix())
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
sid, "", time.Now().Unix())
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
@ -113,6 +115,7 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
func (mp *MysqlProvider) SessionExist(sid string) bool {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)

View File

@ -1,6 +1,7 @@
package session
import (
"net/http"
"strconv"
"strings"
"sync"
@ -58,16 +59,14 @@ func (rs *RedisSessionStore) SessionID() string {
return rs.sid
}
func (rs *RedisSessionStore) SessionRelease() {
func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
defer rs.c.Close()
if len(rs.values) > 0 {
b, err := encodeGob(rs.values)
if err != nil {
return
}
rs.c.Do("SET", rs.sid, string(b))
rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime)
b, err := encodeGob(rs.values)
if err != nil {
return
}
rs.c.Do("SET", rs.sid, string(b))
rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime)
}
type RedisProvider struct {

View File

@ -1,6 +1,8 @@
package session
import (
"crypto/aes"
"encoding/json"
"testing"
)
@ -26,3 +28,82 @@ func Test_gob(t *testing.T) {
t.Error("decode int error")
}
}
func TestGenerate(t *testing.T) {
str := generateRandomKey(20)
if len(str) != 20 {
t.Fatal("generate length is not equal to 20")
}
}
func TestCookieEncodeDecode(t *testing.T) {
hashKey := "testhashKey"
blockkey := generateRandomKey(16)
block, err := aes.NewCipher(blockkey)
if err != nil {
t.Fatal("NewCipher:", err)
}
securityName := string(generateRandomKey(20))
val := make(map[interface{}]interface{})
val["name"] = "astaxie"
val["gender"] = "male"
str, err := encodeCookie(block, hashKey, securityName, val)
if err != nil {
t.Fatal("encodeCookie:", err)
}
dst := make(map[interface{}]interface{})
dst, err = decodeCookie(block, hashKey, securityName, str, 3600)
if err != nil {
t.Fatal("decodeCookie", err)
}
if dst["name"] != "astaxie" {
t.Fatal("dst get map error")
}
if dst["gender"] != "male" {
t.Fatal("dst get map error")
}
}
func TestParseConfig(t *testing.T) {
s := `{"cookieName":"gosessionid","gclifetime":3600}`
cf := new(managerConfig)
cf.EnableSetCookie = true
err := json.Unmarshal([]byte(s), cf)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
cf2 := new(managerConfig)
cf2.EnableSetCookie = true
err = json.Unmarshal([]byte(cc), cf2)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf2.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf2.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
if cf2.EnableSetCookie != false {
t.Fatal("parseconfig get enableSetCookie error")
}
cconfig := new(cookieConfig)
err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
if err != nil {
t.Fatal("parse ProviderConfig err,", err)
}
if cconfig.CookieName != "gosessionid" {
t.Fatal("ProviderConfig get cookieName error")
}
if cconfig.SecurityKey != "beegocookiehashkey" {
t.Fatal("ProviderConfig get securityKey error")
}
}

188
session/sess_utils.go Normal file
View File

@ -0,0 +1,188 @@
package session
import (
"bytes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"crypto/subtle"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"io"
"strconv"
"time"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}
// generateRandomKey creates a random key with the given strength.
func generateRandomKey(strength int) []byte {
k := make([]byte, strength)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}
// Encryption -----------------------------------------------------------------
// encrypt encrypts a value using the given block in counter mode.
//
// A random initialization vector (http://goo.gl/zF67k) with the length of the
// block size is prepended to the resulting ciphertext.
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
iv := generateRandomKey(block.BlockSize())
if iv == nil {
return nil, errors.New("encrypt: failed to generate random iv")
}
// Encrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
// Return iv + ciphertext.
return append(iv, value...), nil
}
// decrypt decrypts a value using the given block in counter mode.
//
// The value to be decrypted must be prepended by a initialization vector
// (http://goo.gl/zF67k) with the length of the block size.
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
size := block.BlockSize()
if len(value) > size {
// Extract iv.
iv := value[:size]
// Extract ciphertext.
value = value[size:]
// Decrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
return value, nil
}
return nil, errors.New("decrypt: the value could not be decrypted")
}
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
var err error
var b []byte
// 1. encodeGob.
if b, err = encodeGob(value); err != nil {
return "", err
}
// 2. Encrypt (optional).
if b, err = encrypt(block, b); err != nil {
return "", err
}
b = encode(b)
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
h := hmac.New(sha1.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
// Append mac, remove name.
b = append(b, sig...)[len(name)+1:]
// 4. Encode to base64.
b = encode(b)
// Done.
return string(b), nil
}
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
// 1. Decode from base64.
b, err := decode([]byte(value))
if err != nil {
return nil, err
}
// 2. Verify MAC. Value is "date|value|mac".
parts := bytes.SplitN(b, []byte("|"), 3)
if len(parts) != 3 {
return nil, errors.New("Decode: invalid value %v")
}
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
h := hmac.New(sha1.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
return nil, errors.New("Decode: the value is not valid")
}
// 3. Verify date ranges.
var t1 int64
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
return nil, errors.New("Decode: invalid timestamp")
}
t2 := time.Now().UTC().Unix()
if t1 > t2 {
return nil, errors.New("Decode: timestamp is too new")
}
if t1 < t2-gcmaxlifetime {
return nil, errors.New("Decode: expired timestamp")
}
// 4. Decrypt (optional).
b, err = decode(parts[1])
if err != nil {
return nil, err
}
if b, err = decrypt(block, b); err != nil {
return nil, err
}
// 5. decodeGob.
if dst, err := decodeGob(b); err != nil {
return nil, err
} else {
return dst, nil
}
// Done.
return nil, nil
}
// Encoding -------------------------------------------------------------------
// encode encodes a value using base64.
func encode(value []byte) []byte {
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
base64.URLEncoding.Encode(encoded, value)
return encoded
}
// decode decodes a cookie using base64.
func decode(value []byte) ([]byte, error) {
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
b, err := base64.URLEncoding.Decode(decoded, value)
if err != nil {
return nil, err
}
return decoded[:b], nil
}

View File

@ -6,6 +6,7 @@ import (
"crypto/rand"
"crypto/sha1"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
@ -14,16 +15,16 @@ import (
)
type SessionStore interface {
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID
SessionRelease() // release the resource & save data to provider
Flush() error //delete all data
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data
}
type Provider interface {
SessionInit(maxlifetime int64, savePath string) error
SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) bool
SessionRegenerate(oldsid, sid string) (SessionStore, error)
@ -47,15 +48,22 @@ func Register(name string, provide Provider) {
provides[name] = provide
}
type managerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
Maxlifetime int64 `json:"maxLifetime"`
Maxage int `json:"maxage"`
Secure bool `json:"secure"`
SessionIDHashFunc string `json:"sessionIDHashFunc"`
SessionIDHashKey string `json:"sessionIDHashKey"`
CookieLifeTime int64 `json:"cookieLifeTime"`
ProviderConfig string `json:"providerConfig"`
}
type Manager struct {
cookieName string //private cookiename
provider Provider
maxlifetime int64
hashfunc string //support md5 & sha1
hashkey string
maxage int //cookielifetime
secure bool
options []interface{}
provider Provider
config *managerConfig
}
//options
@ -63,74 +71,54 @@ type Manager struct {
//2. hashfunc default sha1
//3. hashkey default beegosessionkey
//4. maxage default is none
func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) {
func NewManager(provideName, config string) (*Manager, error) {
provider, ok := provides[provideName]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
}
provider.SessionInit(maxlifetime, savePath)
secure := false
if len(options) > 0 {
secure = options[0].(bool)
cf := new(managerConfig)
cf.EnableSetCookie = true
err := json.Unmarshal([]byte(config), cf)
if err != nil {
return nil, err
}
hashfunc := "sha1"
if len(options) > 1 {
hashfunc = options[1].(string)
if cf.Maxlifetime == 0 {
cf.Maxlifetime = cf.Gclifetime
}
hashkey := "beegosessionkey"
if len(options) > 2 {
hashkey = options[2].(string)
err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
if err != nil {
return nil, err
}
maxage := -1
if len(options) > 3 {
switch options[3].(type) {
case int:
if options[3].(int) > 0 {
maxage = options[3].(int)
} else if options[3].(int) < 0 {
maxage = 0
}
case int64:
if options[3].(int64) > 0 {
maxage = int(options[3].(int64))
} else if options[3].(int64) < 0 {
maxage = 0
}
case int32:
if options[3].(int32) > 0 {
maxage = int(options[3].(int32))
} else if options[3].(int32) < 0 {
maxage = 0
}
}
if cf.SessionIDHashFunc == "" {
cf.SessionIDHashFunc = "sha1"
}
if cf.SessionIDHashKey == "" {
cf.SessionIDHashKey = string(generateRandomKey(16))
}
return &Manager{
provider: provider,
cookieName: cookieName,
maxlifetime: maxlifetime,
hashfunc: hashfunc,
hashkey: hashkey,
maxage: maxage,
secure: secure,
options: options,
provider,
cf,
}, nil
}
//get Session
func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
cookie, err := r.Cookie(manager.cookieName)
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
sid := manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName,
cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: true,
Secure: manager.secure}
if manager.maxage >= 0 {
cookie.MaxAge = manager.maxage
Secure: manager.config.Secure}
if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.config.Maxage
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
}
http.SetCookie(w, cookie)
r.AddCookie(cookie)
} else {
sid, _ := url.QueryUnescape(cookie.Value)
@ -139,15 +127,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
} else {
sid = manager.sessionId(r)
session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName,
cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: true,
Secure: manager.secure}
if manager.maxage >= 0 {
cookie.MaxAge = manager.maxage
Secure: manager.config.Secure}
if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.config.Maxage
}
if manager.config.EnableSetCookie {
http.SetCookie(w, cookie)
}
http.SetCookie(w, cookie)
r.AddCookie(cookie)
}
}
@ -156,13 +146,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
//Destroy sessionid
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(manager.cookieName)
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil || cookie.Value == "" {
return
} else {
manager.provider.SessionDestroy(cookie.Value)
expiration := time.Now()
cookie := http.Cookie{Name: manager.cookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1}
cookie := http.Cookie{Name: manager.config.CookieName,
Path: "/",
HttpOnly: true,
Expires: expiration,
MaxAge: -1}
http.SetCookie(w, &cookie)
}
}
@ -174,20 +168,20 @@ func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err erro
func (manager *Manager) GC() {
manager.provider.SessionGC()
time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() })
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
}
func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
sid := manager.sessionId(r)
cookie, err := r.Cookie(manager.cookieName)
cookie, err := r.Cookie(manager.config.CookieName)
if err != nil && cookie.Value == "" {
//delete old cookie
session, _ = manager.provider.SessionRead(sid)
cookie = &http.Cookie{Name: manager.cookieName,
cookie = &http.Cookie{Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
Path: "/",
HttpOnly: true,
Secure: manager.secure,
Secure: manager.config.Secure,
}
} else {
oldsid, _ := url.QueryUnescape(cookie.Value)
@ -196,8 +190,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
cookie.HttpOnly = true
cookie.Path = "/"
}
if manager.maxage >= 0 {
cookie.MaxAge = manager.maxage
if manager.config.Maxage >= 0 {
cookie.MaxAge = manager.config.Maxage
}
http.SetCookie(w, cookie)
r.AddCookie(cookie)
@ -209,12 +203,12 @@ func (manager *Manager) GetActiveSession() int {
}
func (manager *Manager) SetHashFunc(hasfunc, hashkey string) {
manager.hashfunc = hasfunc
manager.hashkey = hashkey
manager.config.SessionIDHashFunc = hasfunc
manager.config.SessionIDHashKey = hashkey
}
func (manager *Manager) SetSecure(secure bool) {
manager.secure = secure
manager.config.Secure = secure
}
//remote_addr cruunixnano randdata
@ -224,16 +218,16 @@ func (manager *Manager) sessionId(r *http.Request) (sid string) {
return ""
}
sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
if manager.hashfunc == "md5" {
if manager.config.SessionIDHashFunc == "md5" {
h := md5.New()
h.Write([]byte(sig))
sid = hex.EncodeToString(h.Sum(nil))
} else if manager.hashfunc == "sha1" {
h := hmac.New(sha1.New, []byte(manager.hashkey))
} else if manager.config.SessionIDHashFunc == "sha1" {
h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
fmt.Fprintf(h, "%s", sig)
sid = hex.EncodeToString(h.Sum(nil))
} else {
h := hmac.New(sha1.New, []byte(manager.hashkey))
h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
fmt.Fprintf(h, "%s", sig)
sid = hex.EncodeToString(h.Sum(nil))
}

45
utils/captcha/README.md Normal file
View File

@ -0,0 +1,45 @@
# Captcha
an example for use captcha
```
package controllers
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/cache"
"github.com/astaxie/beego/utils/captcha"
)
var cpt *captcha.Captcha
func init() {
// use beego cache system store the captcha data
store := cache.NewMemoryCache()
cpt = captcha.NewWithFilter("/captcha/", store)
}
type MainController struct {
beego.Controller
}
func (this *MainController) Get() {
this.TplNames = "index.tpl"
}
func (this *MainController) Post() {
this.TplNames = "index.tpl"
this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
}
```
template usage
```
{{.Success}}
<form action="/" method="post">
{{create_captcha}}
<input name="captcha" type="text">
</form>
```

248
utils/captcha/captcha.go Normal file
View File

@ -0,0 +1,248 @@
// an example for use captcha
//
// ```
// package controllers
//
// import (
// "github.com/astaxie/beego"
// "github.com/astaxie/beego/cache"
// "github.com/astaxie/beego/utils/captcha"
// )
//
// var cpt *captcha.Captcha
//
// func init() {
// // use beego cache system store the captcha data
// store := cache.NewMemoryCache()
// cpt = captcha.NewWithFilter("/captcha/", store)
// }
//
// type MainController struct {
// beego.Controller
// }
//
// func (this *MainController) Get() {
// this.TplNames = "index.tpl"
// }
//
// func (this *MainController) Post() {
// this.TplNames = "index.tpl"
//
// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request)
// }
// ```
//
// template usage
//
// ```
// {{.Success}}
// <form action="/" method="post">
// {{create_captcha}}
// <input name="captcha" type="text">
// </form>
// ```
package captcha
import (
"fmt"
"html/template"
"net/http"
"path"
"strings"
"github.com/astaxie/beego"
"github.com/astaxie/beego/cache"
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/utils"
)
var (
defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
)
const (
// default captcha attributes
challengeNums = 6
expiration = 600
fieldIdName = "captcha_id"
fieldCaptchaName = "captcha"
cachePrefix = "captcha_"
urlPrefix = "/captcha/"
)
type Captcha struct {
// beego cache store
store cache.Cache
// url prefix for captcha image
urlPrefix string
// specify captcha id input field name
FieldIdName string
// specify captcha result input field name
FieldCaptchaName string
// captcha image width and height
StdWidth int
StdHeight int
// captcha chars nums
ChallengeNums int
// captcha expiration seconds
Expiration int64
// cache key prefix
CachePrefix string
}
func (c *Captcha) key(id string) string {
return c.CachePrefix + id
}
func (c *Captcha) genRandChars() []byte {
return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...)
}
// beego filter handler for serve captcha image
func (c *Captcha) Handler(ctx *context.Context) {
var chars []byte
id := path.Base(ctx.Request.RequestURI)
if i := strings.Index(id, "."); i != -1 {
id = id[:i]
}
key := c.key(id)
if v, ok := c.store.Get(key).([]byte); ok {
chars = v
} else {
ctx.Output.SetStatus(404)
ctx.WriteString("captcha not found")
return
}
// reload captcha
if len(ctx.Input.Query("reload")) > 0 {
chars = c.genRandChars()
if err := c.store.Put(key, chars, c.Expiration); err != nil {
ctx.Output.SetStatus(500)
ctx.WriteString("captcha reload error")
beego.Error("Reload Create Captcha Error:", err)
return
}
}
img := NewImage(chars, c.StdWidth, c.StdHeight)
if _, err := img.WriteTo(ctx.ResponseWriter); err != nil {
beego.Error("Write Captcha Image Error:", err)
}
}
// tempalte func for output html
func (c *Captcha) CreateCaptchaHtml() template.HTML {
value, err := c.CreateCaptcha()
if err != nil {
beego.Error("Create Captcha Error:", err)
return ""
}
// create html
return template.HTML(fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`+
`<a class="captcha" href="javascript:">`+
`<img onclick="this.src=('%s%s.png?reload='+(new Date()).getTime())" class="captcha-img" src="%s%s.png">`+
`</a>`, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value))
}
// create a new captcha id
func (c *Captcha) CreateCaptcha() (string, error) {
// generate captcha id
id := string(utils.RandomCreateBytes(15))
// get the captcha chars
chars := c.genRandChars()
// save to store
if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil {
return "", err
}
return id, nil
}
// verify from a request
func (c *Captcha) VerifyReq(req *http.Request) bool {
req.ParseForm()
return c.Verify(req.Form.Get(c.FieldIdName), req.Form.Get(c.FieldCaptchaName))
}
// direct verify id and challenge string
func (c *Captcha) Verify(id string, challenge string) (success bool) {
if len(challenge) == 0 || len(id) == 0 {
return
}
var chars []byte
key := c.key(id)
if v, ok := c.store.Get(key).([]byte); ok && len(v) == len(challenge) {
chars = v
} else {
return
}
defer func() {
// finally remove it
c.store.Delete(key)
}()
// verify challenge
for i, c := range chars {
if c != challenge[i]-48 {
return
}
}
return true
}
// create a new captcha.Captcha
func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha {
cpt := &Captcha{}
cpt.store = store
cpt.FieldIdName = fieldIdName
cpt.FieldCaptchaName = fieldCaptchaName
cpt.ChallengeNums = challengeNums
cpt.Expiration = expiration
cpt.CachePrefix = cachePrefix
cpt.StdWidth = stdWidth
cpt.StdHeight = stdHeight
if len(urlPrefix) == 0 {
urlPrefix = urlPrefix
}
if urlPrefix[len(urlPrefix)-1] != '/' {
urlPrefix += "/"
}
cpt.urlPrefix = urlPrefix
return cpt
}
// create a new captcha.Captcha and auto AddFilter for serve captacha image
// and add a tempalte func for output html
func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha {
cpt := NewCaptcha(urlPrefix, store)
// create filter for serve captcha image
beego.AddFilter(urlPrefix+":", "BeforeRouter", cpt.Handler)
// add to template func map
beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml)
return cpt
}

484
utils/captcha/image.go Normal file
View File

@ -0,0 +1,484 @@
// modifiy and integrated to Beego from https://github.com/dchest/captcha
package captcha
import (
"bytes"
"image"
"image/color"
"image/png"
"io"
"math"
)
const (
fontWidth = 11
fontHeight = 18
blackChar = 1
// Standard width and height of a captcha image.
stdWidth = 240
stdHeight = 80
// Maximum absolute skew factor of a single digit.
maxSkew = 0.7
// Number of background circles.
circleCount = 20
)
var font = [][]byte{
{ // 0
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 1
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
},
{ // 2
0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
},
{ // 3
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 4
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0,
0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
},
{ // 5
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 6
0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0,
1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,
1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 7
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
},
{ // 8
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
},
{ // 9
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,
0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
},
}
type Image struct {
*image.Paletted
numWidth int
numHeight int
dotSize int
}
var prng = &siprng{}
// randIntn returns a pseudorandom non-negative int in range [0, n).
func randIntn(n int) int {
return prng.Intn(n)
}
// randInt returns a pseudorandom int in range [from, to].
func randInt(from, to int) int {
return prng.Intn(to+1-from) + from
}
// randFloat returns a pseudorandom float64 in range [from, to].
func randFloat(from, to float64) float64 {
return (to-from)*prng.Float64() + from
}
func randomPalette() color.Palette {
p := make([]color.Color, circleCount+1)
// Transparent color.
p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00}
// Primary color.
prim := color.RGBA{
uint8(randIntn(129)),
uint8(randIntn(129)),
uint8(randIntn(129)),
0xFF,
}
p[1] = prim
// Circle colors.
for i := 2; i <= circleCount; i++ {
p[i] = randomBrightness(prim, 255)
}
return p
}
// NewImage returns a new captcha image of the given width and height with the
// given digits, where each digit must be in range 0-9.
func NewImage(digits []byte, width, height int) *Image {
m := new(Image)
m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette())
m.calculateSizes(width, height, len(digits))
// Randomly position captcha inside the image.
maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize
maxy := height - m.numHeight - m.dotSize*2
var border int
if width > height {
border = height / 5
} else {
border = width / 5
}
x := randInt(border, maxx-border)
y := randInt(border, maxy-border)
// Draw digits.
for _, n := range digits {
m.drawDigit(font[n], x, y)
x += m.numWidth + m.dotSize
}
// Draw strike-through line.
m.strikeThrough()
// Apply wave distortion.
m.distort(randFloat(5, 10), randFloat(100, 200))
// Fill image with random circles.
m.fillWithCircles(circleCount, m.dotSize)
return m
}
// encodedPNG encodes an image to PNG and returns
// the result as a byte slice.
func (m *Image) encodedPNG() []byte {
var buf bytes.Buffer
if err := png.Encode(&buf, m.Paletted); err != nil {
panic(err.Error())
}
return buf.Bytes()
}
// WriteTo writes captcha image in PNG format into the given writer.
func (m *Image) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.encodedPNG())
return int64(n), err
}
func (m *Image) calculateSizes(width, height, ncount int) {
// Goal: fit all digits inside the image.
var border int
if width > height {
border = height / 4
} else {
border = width / 4
}
// Convert everything to floats for calculations.
w := float64(width - border*2)
h := float64(height - border*2)
// fw takes into account 1-dot spacing between digits.
fw := float64(fontWidth + 1)
fh := float64(fontHeight)
nc := float64(ncount)
// Calculate the width of a single digit taking into account only the
// width of the image.
nw := w / nc
// Calculate the height of a digit from this width.
nh := nw * fh / fw
// Digit too high?
if nh > h {
// Fit digits based on height.
nh = h
nw = fw / fh * nh
}
// Calculate dot size.
m.dotSize = int(nh / fh)
// Save everything, making the actual width smaller by 1 dot to account
// for spacing between digits.
m.numWidth = int(nw) - m.dotSize
m.numHeight = int(nh)
}
func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) {
for x := fromX; x <= toX; x++ {
m.SetColorIndex(x, y, colorIdx)
}
}
func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) {
f := 1 - radius
dfx := 1
dfy := -2 * radius
xo := 0
yo := radius
m.SetColorIndex(x, y+radius, colorIdx)
m.SetColorIndex(x, y-radius, colorIdx)
m.drawHorizLine(x-radius, x+radius, y, colorIdx)
for xo < yo {
if f >= 0 {
yo--
dfy += 2
f += dfy
}
xo++
dfx += 2
f += dfx
m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx)
m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx)
m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx)
m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx)
}
}
func (m *Image) fillWithCircles(n, maxradius int) {
maxx := m.Bounds().Max.X
maxy := m.Bounds().Max.Y
for i := 0; i < n; i++ {
colorIdx := uint8(randInt(1, circleCount-1))
r := randInt(1, maxradius)
m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx)
}
}
func (m *Image) strikeThrough() {
maxx := m.Bounds().Max.X
maxy := m.Bounds().Max.Y
y := randInt(maxy/3, maxy-maxy/3)
amplitude := randFloat(5, 20)
period := randFloat(80, 180)
dx := 2.0 * math.Pi / period
for x := 0; x < maxx; x++ {
xo := amplitude * math.Cos(float64(y)*dx)
yo := amplitude * math.Sin(float64(x)*dx)
for yn := 0; yn < m.dotSize; yn++ {
r := randInt(0, m.dotSize)
m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1)
}
}
}
func (m *Image) drawDigit(digit []byte, x, y int) {
skf := randFloat(-maxSkew, maxSkew)
xs := float64(x)
r := m.dotSize / 2
y += randInt(-r, r)
for yo := 0; yo < fontHeight; yo++ {
for xo := 0; xo < fontWidth; xo++ {
if digit[yo*fontWidth+xo] != blackChar {
continue
}
m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1)
}
xs += skf
x = int(xs)
}
}
func (m *Image) distort(amplude float64, period float64) {
w := m.Bounds().Max.X
h := m.Bounds().Max.Y
oldm := m.Paletted
newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette)
dx := 2.0 * math.Pi / period
for x := 0; x < w; x++ {
for y := 0; y < h; y++ {
xo := amplude * math.Sin(float64(y)*dx)
yo := amplude * math.Cos(float64(x)*dx)
newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo)))
}
}
m.Paletted = newm
}
func randomBrightness(c color.RGBA, max uint8) color.RGBA {
minc := min3(c.R, c.G, c.B)
maxc := max3(c.R, c.G, c.B)
if maxc > max {
return c
}
n := randIntn(int(max-maxc)) - int(minc)
return color.RGBA{
uint8(int(c.R) + n),
uint8(int(c.G) + n),
uint8(int(c.B) + n),
uint8(c.A),
}
}
func min3(x, y, z uint8) (m uint8) {
m = x
if y < m {
m = y
}
if z < m {
m = z
}
return
}
func max3(x, y, z uint8) (m uint8) {
m = x
if y > m {
m = y
}
if z > m {
m = z
}
return
}

View File

@ -0,0 +1,38 @@
package captcha
import (
"testing"
"github.com/astaxie/beego/utils"
)
type byteCounter struct {
n int64
}
func (bc *byteCounter) Write(b []byte) (int, error) {
bc.n += int64(len(b))
return len(b), nil
}
func BenchmarkNewImage(b *testing.B) {
b.StopTimer()
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
b.StartTimer()
for i := 0; i < b.N; i++ {
NewImage(d, stdWidth, stdHeight)
}
}
func BenchmarkImageWriteTo(b *testing.B) {
b.StopTimer()
d := utils.RandomCreateBytes(challengeNums, defaultChars...)
b.StartTimer()
counter := &byteCounter{}
for i := 0; i < b.N; i++ {
img := NewImage(d, stdWidth, stdHeight)
img.WriteTo(counter)
b.SetBytes(counter.n)
counter.n = 0
}
}

264
utils/captcha/siprng.go Normal file
View File

@ -0,0 +1,264 @@
// modifiy and integrated to Beego from https://github.com/dchest/captcha
package captcha
import (
"crypto/rand"
"encoding/binary"
"io"
"sync"
)
// siprng is PRNG based on SipHash-2-4.
type siprng struct {
mu sync.Mutex
k0, k1, ctr uint64
}
// siphash implements SipHash-2-4, accepting a uint64 as a message.
func siphash(k0, k1, m uint64) uint64 {
// Initialization.
v0 := k0 ^ 0x736f6d6570736575
v1 := k1 ^ 0x646f72616e646f6d
v2 := k0 ^ 0x6c7967656e657261
v3 := k1 ^ 0x7465646279746573
t := uint64(8) << 56
// Compression.
v3 ^= m
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
v0 ^= m
// Compress last block.
v3 ^= t
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
v0 ^= t
// Finalization.
v2 ^= 0xff
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 3.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
// Round 4.
v0 += v1
v1 = v1<<13 | v1>>(64-13)
v1 ^= v0
v0 = v0<<32 | v0>>(64-32)
v2 += v3
v3 = v3<<16 | v3>>(64-16)
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>(64-21)
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>(64-17)
v1 ^= v2
v2 = v2<<32 | v2>>(64-32)
return v0 ^ v1 ^ v2 ^ v3
}
// rekey sets a new PRNG key, which is read from crypto/rand.
func (p *siprng) rekey() {
var k [16]byte
if _, err := io.ReadFull(rand.Reader, k[:]); err != nil {
panic(err.Error())
}
p.k0 = binary.LittleEndian.Uint64(k[0:8])
p.k1 = binary.LittleEndian.Uint64(k[8:16])
p.ctr = 1
}
// Uint64 returns a new pseudorandom uint64.
// It rekeys PRNG on the first call and every 64 MB of generated data.
func (p *siprng) Uint64() uint64 {
p.mu.Lock()
if p.ctr == 0 || p.ctr > 8*1024*1024 {
p.rekey()
}
v := siphash(p.k0, p.k1, p.ctr)
p.ctr++
p.mu.Unlock()
return v
}
func (p *siprng) Int63() int64 {
return int64(p.Uint64() & 0x7fffffffffffffff)
}
func (p *siprng) Uint32() uint32 {
return uint32(p.Uint64())
}
func (p *siprng) Int31() int32 {
return int32(p.Uint32() & 0x7fffffff)
}
func (p *siprng) Intn(n int) int {
if n <= 0 {
panic("invalid argument to Intn")
}
if n <= 1<<31-1 {
return int(p.Int31n(int32(n)))
}
return int(p.Int63n(int64(n)))
}
func (p *siprng) Int63n(n int64) int64 {
if n <= 0 {
panic("invalid argument to Int63n")
}
max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
v := p.Int63()
for v > max {
v = p.Int63()
}
return v % n
}
func (p *siprng) Int31n(n int32) int32 {
if n <= 0 {
panic("invalid argument to Int31n")
}
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
v := p.Int31()
for v > max {
v = p.Int31()
}
return v % n
}
func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) }

View File

@ -0,0 +1,19 @@
package captcha
import "testing"
func TestSiphash(t *testing.T) {
good := uint64(0xe849e8bb6ffe2567)
cur := siphash(0, 0, 0)
if cur != good {
t.Fatalf("siphash: expected %x, got %x", good, cur)
}
}
func BenchmarkSiprng(b *testing.B) {
b.SetBytes(8)
p := &siprng{}
for i := 0; i < b.N; i++ {
p.Uint64()
}
}

299
utils/mail.go Normal file
View File

@ -0,0 +1,299 @@
package utils
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"net/mail"
"net/smtp"
"net/textproto"
"os"
"path"
"path/filepath"
"strconv"
"strings"
)
const (
maxLineLength = 76
)
// Email is the type used for email messages
type Email struct {
Auth smtp.Auth
Identity string `json:"identity"`
Username string `json:"username"`
Password string `json:"password"`
Host string `json:"host"`
Port int `json:"port"`
From string `json:"from"`
To []string
Bcc []string
Cc []string
Subject string
Text string // Plaintext message (optional)
HTML string // Html message (optional)
Headers textproto.MIMEHeader
Attachments []*Attachment
ReadReceipt []string
}
// Attachment is a struct representing an email attachment.
// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question
type Attachment struct {
Filename string
Header textproto.MIMEHeader
Content []byte
}
func NewEMail(config string) *Email {
e := new(Email)
e.Headers = textproto.MIMEHeader{}
err := json.Unmarshal([]byte(config), e)
if err != nil {
return nil
}
if e.From == "" {
e.From = e.Username
}
return e
}
// make all send information to byte
func (e *Email) Bytes() ([]byte, error) {
buff := &bytes.Buffer{}
w := multipart.NewWriter(buff)
// Set the appropriate headers (overwriting any conflicts)
// Leave out Bcc (only included in envelope headers)
e.Headers.Set("To", strings.Join(e.To, ","))
if e.Cc != nil {
e.Headers.Set("Cc", strings.Join(e.Cc, ","))
}
e.Headers.Set("From", e.From)
e.Headers.Set("Subject", e.Subject)
if len(e.ReadReceipt) != 0 {
e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ","))
}
e.Headers.Set("MIME-Version", "1.0")
e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary()))
// Write the envelope headers (including any custom headers)
if err := headerToBytes(buff, e.Headers); err != nil {
return nil, fmt.Errorf("Failed to render message headers: %s", err)
}
// Start the multipart/mixed part
fmt.Fprintf(buff, "--%s\r\n", w.Boundary())
header := textproto.MIMEHeader{}
// Check to see if there is a Text or HTML field
if e.Text != "" || e.HTML != "" {
subWriter := multipart.NewWriter(buff)
// Create the multipart alternative part
header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary()))
// Write the header
if err := headerToBytes(buff, header); err != nil {
return nil, fmt.Errorf("Failed to render multipart message headers: %s", err)
}
// Create the body sections
if e.Text != "" {
header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8"))
header.Set("Content-Transfer-Encoding", "quoted-printable")
if _, err := subWriter.CreatePart(header); err != nil {
return nil, err
}
// Write the text
if err := quotePrintEncode(buff, e.Text); err != nil {
return nil, err
}
}
if e.HTML != "" {
header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8"))
header.Set("Content-Transfer-Encoding", "quoted-printable")
if _, err := subWriter.CreatePart(header); err != nil {
return nil, err
}
// Write the text
if err := quotePrintEncode(buff, e.HTML); err != nil {
return nil, err
}
}
if err := subWriter.Close(); err != nil {
return nil, err
}
}
// Create attachment part, if necessary
for _, a := range e.Attachments {
ap, err := w.CreatePart(a.Header)
if err != nil {
return nil, err
}
// Write the base64Wrapped content to the part
base64Wrap(ap, a.Content)
}
if err := w.Close(); err != nil {
return nil, err
}
return buff.Bytes(), nil
}
// Attach file to the send mail
func (e *Email) AttachFile(filename string) (a *Attachment, err error) {
f, err := os.Open(filename)
if err != nil {
return
}
ct := mime.TypeByExtension(filepath.Ext(filename))
basename := path.Base(filename)
return e.Attach(f, basename, ct)
}
// Attach is used to attach content from an io.Reader to the email.
// Required parameters include an io.Reader, the desired filename for the attachment, and the Content-Type
func (e *Email) Attach(r io.Reader, filename string, c string) (a *Attachment, err error) {
var buffer bytes.Buffer
if _, err = io.Copy(&buffer, r); err != nil {
return
}
at := &Attachment{
Filename: filename,
Header: textproto.MIMEHeader{},
Content: buffer.Bytes(),
}
// Get the Content-Type to be used in the MIMEHeader
if c != "" {
at.Header.Set("Content-Type", c)
} else {
// If the Content-Type is blank, set the Content-Type to "application/octet-stream"
at.Header.Set("Content-Type", "application/octet-stream")
}
at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename))
at.Header.Set("Content-Transfer-Encoding", "base64")
e.Attachments = append(e.Attachments, at)
return at, nil
}
func (e *Email) Send() error {
if e.Auth == nil {
e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host)
}
// Merge the To, Cc, and Bcc fields
to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc))
to = append(append(append(to, e.To...), e.Cc...), e.Bcc...)
// Check to make sure there is at least one recipient and one "From" address
if e.From == "" || len(to) == 0 {
return errors.New("Must specify at least one From address and one To address")
}
from, err := mail.ParseAddress(e.From)
if err != nil {
return err
}
raw, err := e.Bytes()
if err != nil {
return err
}
return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw)
}
// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045)
func quotePrintEncode(w io.Writer, s string) error {
var buf [3]byte
mc := 0
for i := 0; i < len(s); i++ {
c := s[i]
// We're assuming Unix style text formats as input (LF line break), and
// quoted-printble uses CRLF line breaks. (Literal CRs will become
// "=0D", but probably shouldn't be there to begin with!)
if c == '\n' {
io.WriteString(w, "\r\n")
mc = 0
continue
}
var nextOut []byte
if isPrintable(c) {
nextOut = append(buf[:0], c)
} else {
nextOut = buf[:]
qpEscape(nextOut, c)
}
// Add a soft line break if the next (encoded) byte would push this line
// to or past the limit.
if mc+len(nextOut) >= maxLineLength {
if _, err := io.WriteString(w, "=\r\n"); err != nil {
return err
}
mc = 0
}
if _, err := w.Write(nextOut); err != nil {
return err
}
mc += len(nextOut)
}
// No trailing end-of-line?? Soft line break, then. TODO: is this sane?
if mc > 0 {
io.WriteString(w, "=\r\n")
}
return nil
}
// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise
func isPrintable(c byte) bool {
return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t')
}
// qpEscape is a helper function for quotePrintEncode which escapes a
// non-printable byte. Expects len(dest) == 3.
func qpEscape(dest []byte, c byte) {
const nums = "0123456789ABCDEF"
dest[0] = '='
dest[1] = nums[(c&0xf0)>>4]
dest[2] = nums[(c & 0xf)]
}
// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer
func headerToBytes(w io.Writer, t textproto.MIMEHeader) error {
for k, v := range t {
// Write the header key
_, err := fmt.Fprintf(w, "%s:", k)
if err != nil {
return err
}
// Write each value in the header
for _, c := range v {
_, err := fmt.Fprintf(w, " %s\r\n", c)
if err != nil {
return err
}
}
}
return nil
}
// base64Wrap encodeds the attachment content, and wraps it according to RFC 2045 standards (every 76 chars)
// The output is then written to the specified io.Writer
func base64Wrap(w io.Writer, b []byte) {
// 57 raw bytes per 76-byte base64 line.
const maxRaw = 57
// Buffer for each line, including trailing CRLF.
var buffer [maxLineLength + len("\r\n")]byte
copy(buffer[maxLineLength:], "\r\n")
// Process raw chunks until there's no longer enough to fill a line.
for len(b) >= maxRaw {
base64.StdEncoding.Encode(buffer[:], b[:maxRaw])
w.Write(buffer[:])
b = b[maxRaw:]
}
// Handle the last chunk of bytes.
if len(b) > 0 {
out := buffer[:base64.StdEncoding.EncodedLen(len(b))]
base64.StdEncoding.Encode(out, b)
out = append(out, "\r\n"...)
w.Write(out)
}
}

27
utils/mail_test.go Normal file
View File

@ -0,0 +1,27 @@
package utils
import "testing"
func TestMail(t *testing.T) {
config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}`
mail := NewEMail(config)
if mail.Username != "astaxie@gmail.com" {
t.Fatal("email parse get username error")
}
if mail.Password != "astaxie" {
t.Fatal("email parse get password error")
}
if mail.Host != "smtp.gmail.com" {
t.Fatal("email parse get host error")
}
if mail.Port != 587 {
t.Fatal("email parse get port error")
}
mail.To = []string{"xiemengjun@gmail.com"}
mail.From = "astaxie@gmail.com"
mail.Subject = "hi, just from beego!"
mail.Text = "Text Body is, of course, supported!"
mail.HTML = "<h1>Fancy Html is supported, too!</h1>"
mail.AttachFile("/Users/astaxie/github/beego/beego.go")
mail.Send()
}

20
utils/rand.go Normal file
View File

@ -0,0 +1,20 @@
package utils
import (
"crypto/rand"
)
// RandomCreateBytes generate random []byte by specify chars.
func RandomCreateBytes(n int, alphabets ...byte) []byte {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
var bytes = make([]byte, n)
rand.Read(bytes)
for i, b := range bytes {
if len(alphabets) == 0 {
bytes[i] = alphanum[b%byte(len(alphanum))]
} else {
bytes[i] = alphabets[b%byte(len(alphabets))]
}
}
return bytes
}