From 62f54cbbeef58c5bdbf2d656c5e251de4512ba8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=B0=8F=E9=BB=91?= Date: Sat, 28 Dec 2013 20:14:36 +0800 Subject: [PATCH 01/46] fix typo error --- controller.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller.go b/controller.go index 1f148f06..04352706 100644 --- a/controller.go +++ b/controller.go @@ -140,7 +140,7 @@ func (c *Controller) RenderString() (string, error) { return string(b), e } -// RenderBytes returns the bytes of renderd tempate string. Do not send out response. +// RenderBytes returns the bytes of rendered template string. Do not send out response. func (c *Controller) RenderBytes() ([]byte, error) { //if the controller has set layout, then first get the tplname's content set the content to the layout if c.Layout != "" { From 3a08eec1f93eb6f659ceb6fdd17d8caeed1ee63b Mon Sep 17 00:00:00 2001 From: Pengfei Xue Date: Mon, 30 Dec 2013 11:29:35 +0800 Subject: [PATCH 02/46] simplify condition test for trailing / --- router.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/router.go b/router.go index d54576ec..cac3c82c 100644 --- a/router.go +++ b/router.go @@ -575,12 +575,11 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) } // pattern /admin url /admin 200 /admin/ 200 // pattern /admin/ url /admin 301 /admin/ 200 - if requestPath[n-1] != '/' && len(route.pattern) == n+1 && - route.pattern[n] == '/' && route.pattern[:n] == requestPath { + if requestPath[n-1] != '/' && requestPath+"/" == route.pattern { http.Redirect(w, r, requestPath+"/", 301) goto Admin } - if requestPath[n-1] == '/' && n >= 2 && requestPath[:n-2] == route.pattern { + if requestPath[n-1] == '/' && route.pattern+"/" == requestPath { runMethod = p.getRunMethod(r.Method, context, route) if runMethod != "" { runrouter = route.controllerType From 984b0cbf31ea342464d3eb4834f5655f273c636d Mon Sep 17 00:00:00 2001 From: astaxie Date: Mon, 30 Dec 2013 15:06:51 +0800 Subject: [PATCH 03/46] 1. :all param default expr change from (.+) to (.*) 2. add hookfunc to support appstart hook --- beego.go | 21 +++++++++++++++++++++ filter.go | 17 +++++++++++------ fiter_test.go | 29 +++++++++++++++++++++++++++++ router.go | 20 ++++++++++++++------ 4 files changed, 75 insertions(+), 12 deletions(-) diff --git a/beego.go b/beego.go index 64e224da..a7f3fd5a 100644 --- a/beego.go +++ b/beego.go @@ -13,6 +13,13 @@ import ( // beego web framework version. const VERSION = "1.0.1" +type hookfunc func() error //hook function to run +var hooks []hookfunc //hook function slice to store the hookfunc + +func init() { + hooks = make([]hookfunc, 0) +} + // Router adds a patterned controller handler to BeeApp. // it's an alias method of App.Router. func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { @@ -87,6 +94,12 @@ func InsertFilter(pattern string, pos int, filter FilterFunc) *App { return BeeApp } +// The hookfunc will run in beego.Run() +// such as sessionInit, middlerware start, buildtemplate, admin start +func AddAPPStartHook(hf hookfunc) { + hooks = append(hooks, hf) +} + // Run beego application. // it's alias of App.Run. func Run() { @@ -102,6 +115,14 @@ func Run() { //init mime initMime() + // do hooks function + for _, hk := range hooks { + err := hk() + if err != nil { + panic(err) + } + } + if SessionOn { GlobalSessions, _ = session.NewManager(SessionProvider, SessionName, diff --git a/filter.go b/filter.go index 7e0245a6..98868865 100644 --- a/filter.go +++ b/filter.go @@ -28,6 +28,12 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) { if router == mr.pattern { return true, nil } + //pattern /admin router /admin/ match + //pattern /admin/ router /admin don't match, because url will 301 in router + if n := len(router); n > 1 && router[n-1] == '/' && router[:n-2] == mr.pattern { + return true, nil + } + if mr.hasregex { if !mr.regex.MatchString(router) { return false, nil @@ -46,7 +52,7 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) { return false, nil } -func buildFilter(pattern string, filter FilterFunc) *FilterRouter { +func buildFilter(pattern string, filter FilterFunc) (*FilterRouter, error) { mr := new(FilterRouter) mr.params = make(map[int]string) mr.filterFunc = filter @@ -54,7 +60,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter { j := 0 for i, part := range parts { if strings.HasPrefix(part, ":") { - expr := "(.+)" + expr := "(.*)" //a user may choose to override the default expression // similar to expressjs: ‘/user/:id([0-9]+)’ if index := strings.Index(part, "("); index != -1 { @@ -77,7 +83,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter { j++ } if strings.HasPrefix(part, "*") { - expr := "(.+)" + expr := "(.*)" if part == "*.*" { mr.params[j] = ":path" parts[i] = "([^.]+).([^.]+)" @@ -137,12 +143,11 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter { pattern = strings.Join(parts, "/") regex, regexErr := regexp.Compile(pattern) if regexErr != nil { - //TODO add error handling here to avoid panic - panic(regexErr) + return nil, regexErr } mr.regex = regex mr.hasregex = true } mr.pattern = pattern - return mr + return mr, nil } diff --git a/fiter_test.go b/fiter_test.go index 4e9dae6a..7fe9a641 100644 --- a/fiter_test.go +++ b/fiter_test.go @@ -23,3 +23,32 @@ func TestFilter(t *testing.T) { t.Errorf("user define func can't run") } } + +var FilterAdminUser = func(ctx *context.Context) { + ctx.Output.Body([]byte("i am admin")) +} + +// Filter pattern /admin/:all +// all url like /admin/ /admin/xie will all get filter + +func TestPatternTwo(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/", nil) + w := httptest.NewRecorder() + handler := NewControllerRegistor() + handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am admin" { + t.Errorf("filter /admin/ can't run") + } +} + +func TestPatternThree(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/astaxie", nil) + w := httptest.NewRecorder() + handler := NewControllerRegistor() + handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am admin" { + t.Errorf("filter /admin/astaxie can't run") + } +} diff --git a/router.go b/router.go index d54576ec..21479a13 100644 --- a/router.go +++ b/router.go @@ -77,7 +77,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM params := make(map[int]string) for i, part := range parts { if strings.HasPrefix(part, ":") { - expr := "(.+)" + expr := "(.*)" //a user may choose to override the defult expression // similar to expressjs: ‘/user/:id([0-9]+)’ if index := strings.Index(part, "("); index != -1 { @@ -100,7 +100,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM j++ } if strings.HasPrefix(part, "*") { - expr := "(.+)" + expr := "(.*)" if part == "*.*" { params[j] = ":path" parts[i] = "([^.]+).([^.]+)" @@ -238,8 +238,11 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) { // [Deprecated] use InsertFilter. // Add FilterFunc with pattern for action. -func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) { - mr := buildFilter(pattern, filter) +func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) error { + mr, err := buildFilter(pattern, filter) + if err != nil { + return err + } switch action { case "BeforeRouter": p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr) @@ -253,13 +256,18 @@ func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc p.filters[FinishRouter] = append(p.filters[FinishRouter], mr) } p.enableFilter = true + return nil } // Add a FilterFunc with pattern rule and action constant. -func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) { - mr := buildFilter(pattern, filter) +func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) error { + mr, err := buildFilter(pattern, filter) + if err != nil { + return err + } p.filters[pos] = append(p.filters[pos], mr) p.enableFilter = true + return nil } // UrlFor does another controller handler in this request function. From e0e8fa6e2a77db772515c9632c64c9e5789f4065 Mon Sep 17 00:00:00 2001 From: astaxie Date: Mon, 30 Dec 2013 22:51:54 +0800 Subject: [PATCH 04/46] fix #413 --- example/chat/controllers/ws.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/example/chat/controllers/ws.go b/example/chat/controllers/ws.go index 9334336f..9b3f5b10 100644 --- a/example/chat/controllers/ws.go +++ b/example/chat/controllers/ws.go @@ -1,12 +1,13 @@ package controllers import ( - "github.com/astaxie/beego" - "github.com/garyburd/go-websocket/websocket" "io/ioutil" "math/rand" "net/http" "time" + + "github.com/astaxie/beego" + "github.com/gorilla/websocket" ) const ( From 412a4a04de232c3c83fcaedf50f961c0e962c2c1 Mon Sep 17 00:00:00 2001 From: slene Date: Mon, 30 Dec 2013 23:04:13 +0800 Subject: [PATCH 05/46] #384 --- orm/orm_raw.go | 305 ++++++++++++++++++++++++++++++------------------ orm/orm_test.go | 177 ++++++---------------------- 2 files changed, 228 insertions(+), 254 deletions(-) diff --git a/orm/orm_raw.go b/orm/orm_raw.go index 864515ac..7d204876 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -4,7 +4,6 @@ import ( "database/sql" "fmt" "reflect" - "strings" "time" ) @@ -164,65 +163,11 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { } } -func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) { - sIdxes := *sIdxesPtr - refs := *refsPtr - - if typ.Kind() == reflect.Struct { - if typ.String() == "time.Time" { - var ref interface{} - refs = append(refs, &ref) - sIdxes = append(sIdxes, []int{0}) - } else { - idxs := []int{} - outFor: - for idx := 0; idx < typ.NumField(); idx++ { - ctyp := typ.Field(idx) - - tag := ctyp.Tag.Get(defaultStructTagName) - for _, v := range strings.Split(tag, defaultStructTagDelim) { - if v == "-" { - continue outFor - } - } - - tp := ctyp.Type - if tp.Kind() == reflect.Ptr { - tp = tp.Elem() - } - - if tp.String() == "time.Time" { - var ref interface{} - refs = append(refs, &ref) - - } else if tp.Kind() != reflect.Struct { - var ref interface{} - refs = append(refs, &ref) - - } else { - // skip other type - continue - } - - idxs = append(idxs, idx) - } - sIdxes = append(sIdxes, idxs) - } - } else { - var ref interface{} - refs = append(refs, &ref) - sIdxes = append(sIdxes, []int{0}) - } - - *sIdxesPtr = sIdxes - *refsPtr = refs -} - -func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { +func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { nInds := *nIndsPtr cur := 0 - for i, idxs := range sIdxes { + for i := 0; i < len(sInds); i++ { sInd := sInds[i] eTyp := eTyps[i] @@ -258,32 +203,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect o.setFieldValue(ind, value) } cur++ - } else { - hasValue := false - for _, idx := range idxs { - tind := ind.Field(idx) - value := reflect.ValueOf(refs[cur]).Elem().Interface() - if value != nil { - hasValue = true - } - if tind.Kind() == reflect.Ptr { - if value == nil { - tindV := reflect.New(tind.Type()).Elem() - tind.Set(tindV) - } else { - tindV := reflect.New(tind.Type().Elem()) - o.setFieldValue(tindV.Elem(), value) - tind.Set(tindV) - } - } else { - o.setFieldValue(tind, value) - } - cur++ - } - if hasValue == false && isPtr { - val = reflect.New(val.Type()).Elem() - } } + } else { value := reflect.ValueOf(refs[cur]).Elem().Interface() if isPtr && value == nil { @@ -313,15 +234,12 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect } func (o *rawSet) QueryRow(containers ...interface{}) error { - if len(containers) == 0 { - panic(fmt.Errorf(" need at least one arg")) - } - refs := make([]interface{}, 0, len(containers)) - sIdxes := make([][]int, 0) sInds := make([]reflect.Value, 0) eTyps := make([]reflect.Type, 0) + structMode := false + var sMi *modelInfo for _, container := range containers { val := reflect.ValueOf(container) ind := reflect.Indirect(val) @@ -335,44 +253,120 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { if typ.Kind() == reflect.Ptr { typ = typ.Elem() } - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } sInds = append(sInds, ind) eTyps = append(eTyps, etyp) - o.loopInitRefs(typ, &refs, &sIdxes) + if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { + if len(containers) > 1 { + panic(fmt.Errorf(" now support one struct only. see #384")) + } + + structMode = true + fn := getFullName(typ) + if mi, ok := modelCache.getByFN(fn); ok { + sMi = mi + } + } else { + var ref interface{} + refs = append(refs, &ref) + } } query := o.query o.orm.alias.DbBaser.ReplaceMarks(&query) args := getFlatParams(nil, o.args, o.orm.alias.TZ) - row := o.orm.db.QueryRow(query, args...) - - if err := row.Scan(refs...); err == sql.ErrNoRows { - return ErrNoRows - } else if err != nil { + rows, err := o.orm.db.Query(query, args...) + if err != nil { + if err == sql.ErrNoRows { + return ErrNoRows + } return err } - nInds := make([]reflect.Value, len(sInds)) - o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true) - for i, sInd := range sInds { - nInd := nInds[i] - sInd.Set(nInd) + if rows.Next() { + if structMode { + columns, err := rows.Columns() + if err != nil { + return err + } + + columnsMp := make(map[string]interface{}, len(columns)) + + refs = make([]interface{}, 0, len(columns)) + for _, col := range columns { + var ref interface{} + columnsMp[col] = &ref + refs = append(refs, &ref) + } + + if err := rows.Scan(refs...); err != nil { + return err + } + + ind := sInds[0] + + if ind.Kind() == reflect.Ptr { + if ind.IsNil() || !ind.IsValid() { + ind.Set(reflect.New(eTyps[0].Elem())) + } + ind = ind.Elem() + } + + if sMi != nil { + for _, col := range columns { + if fi := sMi.fields.GetByColumn(col); fi != nil { + value := reflect.ValueOf(columnsMp[col]).Elem().Interface() + o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value) + } + } + } else { + for i := 0; i < ind.NumField(); i++ { + f := ind.Field(i) + fe := ind.Type().Field(i) + + var attrs map[string]bool + var tags map[string]string + parseStructTag(fe.Tag.Get("orm"), &attrs, &tags) + var col string + if col = tags["column"]; len(col) == 0 { + col = snakeString(fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } + } + } + + } else { + if err := rows.Scan(refs...); err != nil { + return err + } + + nInds := make([]reflect.Value, len(sInds)) + o.loopSetRefs(refs, sInds, &nInds, eTyps, true) + for i, sInd := range sInds { + nInd := nInds[i] + sInd.Set(nInd) + } + } + + } else { + return ErrNoRows } return nil } func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { - refs := make([]interface{}, 0) - sIdxes := make([][]int, 0) + refs := make([]interface{}, 0, len(containers)) sInds := make([]reflect.Value, 0) eTyps := make([]reflect.Type, 0) + structMode := false + var sMi *modelInfo for _, container := range containers { val := reflect.ValueOf(container) sInd := reflect.Indirect(val) @@ -389,7 +383,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { sInds = append(sInds, sInd) eTyps = append(eTyps, etyp) - o.loopInitRefs(typ, &refs, &sIdxes) + if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { + if len(containers) > 1 { + panic(fmt.Errorf(" now support one struct only. see #384")) + } + + structMode = true + fn := getFullName(typ) + if mi, ok := modelCache.getByFN(fn); ok { + sMi = mi + } + } else { + var ref interface{} + refs = append(refs, &ref) + } } query := o.query @@ -403,21 +410,97 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { nInds := make([]reflect.Value, len(sInds)) + sInd := sInds[0] + var cnt int64 for rows.Next() { - if err := rows.Scan(refs...); err != nil { - return 0, err - } - o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0) + if structMode { + columns, err := rows.Columns() + if err != nil { + return 0, err + } + + columnsMp := make(map[string]interface{}, len(columns)) + + refs = make([]interface{}, 0, len(columns)) + for _, col := range columns { + var ref interface{} + columnsMp[col] = &ref + refs = append(refs, &ref) + } + + if err := rows.Scan(refs...); err != nil { + return 0, err + } + + if cnt == 0 && !sInd.IsNil() { + sInd.Set(reflect.New(sInd.Type()).Elem()) + } + + var ind reflect.Value + if eTyps[0].Kind() == reflect.Ptr { + ind = reflect.New(eTyps[0].Elem()) + } else { + ind = reflect.New(eTyps[0]) + } + + if ind.Kind() == reflect.Ptr { + ind = ind.Elem() + } + + if sMi != nil { + for _, col := range columns { + if fi := sMi.fields.GetByColumn(col); fi != nil { + value := reflect.ValueOf(columnsMp[col]).Elem().Interface() + o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value) + } + } + } else { + for i := 0; i < ind.NumField(); i++ { + f := ind.Field(i) + fe := ind.Type().Field(i) + + var attrs map[string]bool + var tags map[string]string + parseStructTag(fe.Tag.Get("orm"), &attrs, &tags) + var col string + if col = tags["column"]; len(col) == 0 { + col = snakeString(fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } + } + } + + if eTyps[0].Kind() == reflect.Ptr { + ind = ind.Addr() + } + + sInd = reflect.Append(sInd, ind) + + } else { + if err := rows.Scan(refs...); err != nil { + return 0, err + } + + o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0) + } cnt++ } if cnt > 0 { - for i, sInd := range sInds { - nInd := nInds[i] - sInd.Set(nInd) + + if structMode { + sInds[0].Set(sInd) + } else { + for i, sInd := range sInds { + nInd := nInds[i] + sInd.Set(nInd) + } } } diff --git a/orm/orm_test.go b/orm/orm_test.go index d92e3fab..410aa484 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -1322,58 +1322,6 @@ func TestRawQueryRow(t *testing.T) { } } - type Tmp struct { - Skip0 string - Id int - Char *string - Skip1 int `orm:"-"` - Date time.Time - DateTime time.Time - } - - Boolean = false - Text = "" - Int64 = 0 - Uint = 0 - - tmp := new(Tmp) - - cols = []string{ - "int", "char", "date", "datetime", "boolean", "text", "int64", "uint", - } - query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q) - values = []interface{}{ - tmp, &Boolean, &Text, &Int64, &Uint, - } - err = dORM.Raw(query, 1).QueryRow(values...) - throwFailNow(t, err) - - for _, col := range cols { - switch col { - case "id": - throwFail(t, AssertIs(tmp.Id, data_values[col])) - case "char": - c := tmp.Char - throwFail(t, AssertIs(*c, data_values[col])) - case "date": - v := tmp.Date.In(DefaultTimeLoc) - value := data_values[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, test_Date)) - case "datetime": - v := tmp.DateTime.In(DefaultTimeLoc) - value := data_values[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, test_DateTime)) - case "boolean": - throwFail(t, AssertIs(Boolean, data_values[col])) - case "text": - throwFail(t, AssertIs(Text, data_values[col])) - case "int64": - throwFail(t, AssertIs(Int64, data_values[col])) - case "uint": - throwFail(t, AssertIs(Uint, data_values[col])) - } - } - var ( uid int status *int @@ -1394,22 +1342,13 @@ func TestRawQueryRow(t *testing.T) { func TestQueryRows(t *testing.T) { Q := dDbBaser.TableQuote() - cols := []string{ - "id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32", - "int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", - } - var datas []*Data - var dids []int - sep := fmt.Sprintf("%s, %s", Q, Q) - query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q) - num, err := dORM.Raw(query).QueryRows(&datas, &dids) + query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err := dORM.Raw(query).QueryRows(&datas) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(len(datas), 1)) - throwFailNow(t, AssertIs(len(dids), 1)) - throwFailNow(t, AssertIs(dids[0], 1)) ind := reflect.Indirect(reflect.ValueOf(datas[0])) @@ -1427,90 +1366,42 @@ func TestQueryRows(t *testing.T) { throwFail(t, AssertIs(vu == value, true), value, vu) } - type Tmp struct { - Id int - Name string - Skiped0 string `orm:"-"` - Pid *int - Skiped1 Data - Skiped2 *Data + var datas2 []Data + + query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q) + num, err = dORM.Raw(query).QueryRows(&datas2) + throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(len(datas2), 1)) + + ind = reflect.Indirect(reflect.ValueOf(datas2[0])) + + for name, value := range Data_Values { + e := ind.FieldByName(name) + vu := e.Interface() + switch name { + case "Date": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date) + value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date) + case "DateTime": + vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) + value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime) + } + throwFail(t, AssertIs(vu == value, true), value, vu) } - var ( - ids []int - userNames []string - profileIds1 []int - profileIds2 []*int - createds []time.Time - updateds []time.Time - tmps1 []*Tmp - tmps2 []Tmp - ) - cols = []string{ - "id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated", - } - query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q) - num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds) + var ids []int + var usernames []string + num, err = dORM.Raw("SELECT id, user_name FROM user ORDER BY id asc").QueryRows(&ids, &usernames) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 3)) - - var users []User - dORM.QueryTable("user").OrderBy("Id").All(&users) - - for i := 0; i < 3; i++ { - id := ids[i] - name := userNames[i] - pid1 := profileIds1[i] - pid2 := profileIds2[i] - created := createds[i] - updated := updateds[i] - - user := users[i] - throwFailNow(t, AssertIs(id, user.Id)) - throwFailNow(t, AssertIs(name, user.UserName)) - if user.Profile != nil { - throwFailNow(t, AssertIs(pid1, user.Profile.Id)) - throwFailNow(t, AssertIs(*pid2, user.Profile.Id)) - } else { - throwFailNow(t, AssertIs(pid1, 0)) - throwFailNow(t, AssertIs(pid2, nil)) - } - throwFailNow(t, AssertIs(created, user.Created, test_Date)) - throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime)) - - tmp := tmps1[i] - tmp1 := *tmp - throwFailNow(t, AssertIs(tmp1.Id, user.Id)) - throwFailNow(t, AssertIs(tmp1.Name, user.UserName)) - if user.Profile != nil { - pid := tmp1.Pid - throwFailNow(t, AssertIs(*pid, user.Profile.Id)) - } else { - throwFailNow(t, AssertIs(tmp1.Pid, nil)) - } - - tmp2 := tmps2[i] - throwFailNow(t, AssertIs(tmp2.Id, user.Id)) - throwFailNow(t, AssertIs(tmp2.Name, user.UserName)) - if user.Profile != nil { - pid := tmp2.Pid - throwFailNow(t, AssertIs(*pid, user.Profile.Id)) - } else { - throwFailNow(t, AssertIs(tmp2.Pid, nil)) - } - } - - type Sec struct { - Id int - Name string - } - - var tmp []*Sec - query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q) - num, err = dORM.Raw(query).QueryRows(&tmp) - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) - throwFail(t, AssertIs(tmp[0], nil)) + throwFailNow(t, AssertIs(len(ids), 3)) + throwFailNow(t, AssertIs(ids[0], 2)) + throwFailNow(t, AssertIs(usernames[0], "slene")) + throwFailNow(t, AssertIs(ids[1], 3)) + throwFailNow(t, AssertIs(usernames[1], "astaxie")) + throwFailNow(t, AssertIs(ids[2], 4)) + throwFailNow(t, AssertIs(usernames[2], "nobody")) } func TestRawValues(t *testing.T) { From 94ad13c846a9e8a71b64d8560abc4643cd98a850 Mon Sep 17 00:00:00 2001 From: FuXiaoHei Date: Mon, 30 Dec 2013 23:32:57 +0800 Subject: [PATCH 06/46] add comments in logs package --- logs/conn.go | 9 +++++++++ logs/console.go | 7 +++++++ logs/file.go | 21 ++++++++++++++++++--- logs/log.go | 24 ++++++++++++++++++++++-- logs/smtp.go | 18 +++++++++++++++++- 5 files changed, 73 insertions(+), 6 deletions(-) diff --git a/logs/conn.go b/logs/conn.go index a5bc75c5..eed9ae2f 100644 --- a/logs/conn.go +++ b/logs/conn.go @@ -7,6 +7,8 @@ import ( "net" ) +// ConnWriter implements LoggerInterface. +// it writes messages in keep-live tcp connection. type ConnWriter struct { lg *log.Logger innerWriter io.WriteCloser @@ -17,12 +19,15 @@ type ConnWriter struct { Level int `json:"level"` } +// create new ConnWrite returning as LoggerInterface. func NewConn() LoggerInterface { conn := new(ConnWriter) conn.Level = LevelTrace return conn } +// init connection writer with json config. +// json config only need key "level". func (c *ConnWriter) Init(jsonconfig string) error { err := json.Unmarshal([]byte(jsonconfig), c) if err != nil { @@ -31,6 +36,8 @@ func (c *ConnWriter) Init(jsonconfig string) error { return nil } +// write message in connection. +// if connection is down, try to re-connect. func (c *ConnWriter) WriteMsg(msg string, level int) error { if level < c.Level { return nil @@ -49,10 +56,12 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error { return nil } +// implementing method. empty. func (c *ConnWriter) Flush() { } +// destroy connection writer and close tcp listener. func (c *ConnWriter) Destroy() { if c.innerWriter == nil { return diff --git a/logs/console.go b/logs/console.go index 0c7fc1e9..c5fa2380 100644 --- a/logs/console.go +++ b/logs/console.go @@ -6,11 +6,13 @@ import ( "os" ) +// ConsoleWriter implements LoggerInterface and writes messages to terminal. type ConsoleWriter struct { lg *log.Logger Level int `json:"level"` } +// create ConsoleWriter returning as LoggerInterface. func NewConsole() LoggerInterface { cw := new(ConsoleWriter) cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime) @@ -18,6 +20,8 @@ func NewConsole() LoggerInterface { return cw } +// init console logger. +// jsonconfig like '{"level":LevelTrace}'. func (c *ConsoleWriter) Init(jsonconfig string) error { err := json.Unmarshal([]byte(jsonconfig), c) if err != nil { @@ -26,6 +30,7 @@ func (c *ConsoleWriter) Init(jsonconfig string) error { return nil } +// write message in console. func (c *ConsoleWriter) WriteMsg(msg string, level int) error { if level < c.Level { return nil @@ -34,10 +39,12 @@ func (c *ConsoleWriter) WriteMsg(msg string, level int) error { return nil } +// implementing method. empty. func (c *ConsoleWriter) Destroy() { } +// implementing method. empty. func (c *ConsoleWriter) Flush() { } diff --git a/logs/file.go b/logs/file.go index e19c6c7f..d0512e26 100644 --- a/logs/file.go +++ b/logs/file.go @@ -13,6 +13,8 @@ import ( "time" ) +// FileLogWriter implements LoggerInterface. +// It writes messages by lines limit, file size limit, or time frequency. type FileLogWriter struct { *log.Logger mw *MuxWriter @@ -38,17 +40,20 @@ type FileLogWriter struct { Level int `json:"level"` } +// an *os.File writer with locker. type MuxWriter struct { sync.Mutex fd *os.File } +// write to os.File. func (l *MuxWriter) Write(b []byte) (int, error) { l.Lock() defer l.Unlock() return l.fd.Write(b) } +// set os.File in writer. func (l *MuxWriter) SetFd(fd *os.File) { if l.fd != nil { l.fd.Close() @@ -56,6 +61,7 @@ func (l *MuxWriter) SetFd(fd *os.File) { l.fd = fd } +// create a FileLogWriter returning as LoggerInterface. func NewFileWriter() LoggerInterface { w := &FileLogWriter{ Filename: "", @@ -73,15 +79,16 @@ func NewFileWriter() LoggerInterface { return w } -// jsonconfig like this -//{ +// Init file logger with json config. +// jsonconfig like: +// { // "filename":"logs/beego.log", // "maxlines":10000, // "maxsize":1<<30, // "daily":true, // "maxdays":15, // "rotate":true -//} +// } func (w *FileLogWriter) Init(jsonconfig string) error { err := json.Unmarshal([]byte(jsonconfig), w) if err != nil { @@ -94,6 +101,7 @@ func (w *FileLogWriter) Init(jsonconfig string) error { return err } +// start file logger. create log file and set to locker-inside file writer. func (w *FileLogWriter) StartLogger() error { fd, err := w.createLogFile() if err != nil { @@ -122,6 +130,7 @@ func (w *FileLogWriter) docheck(size int) { w.maxsize_cursize += size } +// write logger message into file. func (w *FileLogWriter) WriteMsg(msg string, level int) error { if level < w.Level { return nil @@ -158,6 +167,8 @@ func (w *FileLogWriter) initFd() error { return nil } +// DoRotate means it need to write file in new file. +// new file name like xx.log.2013-01-01.2 func (w *FileLogWriter) DoRotate() error { _, err := os.Lstat(w.Filename) if err == nil { // file exists @@ -211,10 +222,14 @@ func (w *FileLogWriter) deleteOldLog() { }) } +// destroy file logger, close file writer. func (w *FileLogWriter) Destroy() { w.mw.fd.Close() } +// flush file logger. +// there are no buffering messages in file logger in memory. +// flush file means sync file from disk. func (w *FileLogWriter) Flush() { w.mw.fd.Sync() } diff --git a/logs/log.go b/logs/log.go index a9254aaa..b65414cb 100644 --- a/logs/log.go +++ b/logs/log.go @@ -6,6 +6,7 @@ import ( ) const ( + // log message levels LevelTrace = iota LevelDebug LevelInfo @@ -16,6 +17,7 @@ const ( type loggerType func() LoggerInterface +// LoggerInterface defines the behavior of a log provider. type LoggerInterface interface { Init(config string) error WriteMsg(msg string, level int) error @@ -38,6 +40,8 @@ func Register(name string, log loggerType) { adapters[name] = log } +// BeeLogger is default logger in beego application. +// it can contain several providers and log message into all providers. type BeeLogger struct { lock sync.Mutex level int @@ -50,7 +54,9 @@ type logMsg struct { msg string } -// config need to be correct JSON as string: {"interval":360} +// NewLogger returns a new BeeLogger. +// channellen means the number of messages in chan. +// if the buffering chan is full, logger adapters write to file or other way. func NewLogger(channellen int64) *BeeLogger { bl := new(BeeLogger) bl.msg = make(chan *logMsg, channellen) @@ -60,6 +66,8 @@ func NewLogger(channellen int64) *BeeLogger { return bl } +// SetLogger provides a given logger adapter into BeeLogger with config string. +// config need to be correct JSON as string: {"interval":360}. func (bl *BeeLogger) SetLogger(adaptername string, config string) error { bl.lock.Lock() defer bl.lock.Unlock() @@ -73,6 +81,7 @@ func (bl *BeeLogger) SetLogger(adaptername string, config string) error { } } +// remove a logger adapter in BeeLogger. func (bl *BeeLogger) DelLogger(adaptername string) error { bl.lock.Lock() defer bl.lock.Unlock() @@ -96,10 +105,14 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error { return nil } +// set log message level. +// if message level (such as LevelTrace) is less than logger level (such as LevelWarn), ignore message. func (bl *BeeLogger) SetLevel(l int) { bl.level = l } +// start logger chan reading. +// when chan is full, write logs. func (bl *BeeLogger) StartLogger() { for { select { @@ -111,43 +124,50 @@ func (bl *BeeLogger) StartLogger() { } } +// log trace level message. func (bl *BeeLogger) Trace(format string, v ...interface{}) { msg := fmt.Sprintf("[T] "+format, v...) bl.writerMsg(LevelTrace, msg) } +// log debug level message. func (bl *BeeLogger) Debug(format string, v ...interface{}) { msg := fmt.Sprintf("[D] "+format, v...) bl.writerMsg(LevelDebug, msg) } +// log info level message. func (bl *BeeLogger) Info(format string, v ...interface{}) { msg := fmt.Sprintf("[I] "+format, v...) bl.writerMsg(LevelInfo, msg) } +// log warn level message. func (bl *BeeLogger) Warn(format string, v ...interface{}) { msg := fmt.Sprintf("[W] "+format, v...) bl.writerMsg(LevelWarn, msg) } +// log error level message. func (bl *BeeLogger) Error(format string, v ...interface{}) { msg := fmt.Sprintf("[E] "+format, v...) bl.writerMsg(LevelError, msg) } +// log critical level message. func (bl *BeeLogger) Critical(format string, v ...interface{}) { msg := fmt.Sprintf("[C] "+format, v...) bl.writerMsg(LevelCritical, msg) } -//flush all chan data +// flush all chan data. func (bl *BeeLogger) Flush() { for _, l := range bl.outputs { l.Flush() } } +// close logger, flush all chan data and destroy all adapters in BeeLogger. func (bl *BeeLogger) Close() { for { if len(bl.msg) > 0 { diff --git a/logs/smtp.go b/logs/smtp.go index 228977bb..19296887 100644 --- a/logs/smtp.go +++ b/logs/smtp.go @@ -12,7 +12,7 @@ const ( subjectPhrase = "Diagnostic message from server" ) -// smtpWriter is used to send emails via given SMTP-server. +// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server. type SmtpWriter struct { Username string `json:"Username"` Password string `json:"password"` @@ -22,10 +22,21 @@ type SmtpWriter struct { Level int `json:"level"` } +// create smtp writer. func NewSmtpWriter() LoggerInterface { return &SmtpWriter{Level: LevelTrace} } +// init smtp writer with json config. +// config like: +// { +// "Username":"example@gmail.com", +// "password:"password", +// "host":"smtp.gmail.com:465", +// "subject":"email title", +// "sendTos":["email1","email2"], +// "level":LevelError +// } func (s *SmtpWriter) Init(jsonconfig string) error { err := json.Unmarshal([]byte(jsonconfig), s) if err != nil { @@ -34,6 +45,8 @@ func (s *SmtpWriter) Init(jsonconfig string) error { return nil } +// write message in smtp writer. +// it will send an email with subject and only this message. func (s *SmtpWriter) WriteMsg(msg string, level int) error { if level < s.Level { return nil @@ -65,9 +78,12 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error { return err } +// implementing method. empty. func (s *SmtpWriter) Flush() { return } + +// implementing method. empty. func (s *SmtpWriter) Destroy() { return } From 383a04f4c2d0f8ac36a0748888f9aab8ebb0ab79 Mon Sep 17 00:00:00 2001 From: astaxie Date: Tue, 31 Dec 2013 00:34:47 +0800 Subject: [PATCH 07/46] move initmime from beego.Run to hookfunc --- beego.go | 5 ++--- mime.go | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/beego.go b/beego.go index a7f3fd5a..acafa77b 100644 --- a/beego.go +++ b/beego.go @@ -18,6 +18,8 @@ var hooks []hookfunc //hook function slice to store the hookfunc func init() { hooks = make([]hookfunc, 0) + //init mime + AddAPPStartHook(initMime) } // Router adds a patterned controller handler to BeeApp. @@ -112,9 +114,6 @@ func Run() { } } - //init mime - initMime() - // do hooks function for _, hk := range hooks { err := hk() diff --git a/mime.go b/mime.go index 234ac9f3..97ed2449 100644 --- a/mime.go +++ b/mime.go @@ -544,8 +544,9 @@ var mimemaps map[string]string = map[string]string{ ".mustache": "text/html", } -func initMime() { +func initMime() error { for k, v := range mimemaps { mime.AddExtensionType(k, v) } + return nil } From 61c0b3e2860eb2ad23fc83a0e8850c1876178aac Mon Sep 17 00:00:00 2001 From: slene Date: Tue, 31 Dec 2013 09:55:29 +0800 Subject: [PATCH 08/46] fix db locked --- orm/db.go | 10 ++++++++++ orm/orm_raw.go | 11 ++++++++--- orm/orm_test.go | 3 ++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/orm/db.go b/orm/db.go index 3454da79..782e3bc3 100644 --- a/orm/db.go +++ b/orm/db.go @@ -486,6 +486,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con rs = r } + defer rs.Close() + var ref interface{} args = make([]interface{}, 0) @@ -640,6 +642,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi refs[i] = &ref } + defer rs.Close() + slice := ind var cnt int64 @@ -1150,6 +1154,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond refs[i] = &ref } + defer rs.Close() + var ( cnt int64 columns []string @@ -1268,6 +1274,8 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { return tables, err } + defer rows.Close() + for rows.Next() { var table string err := rows.Scan(&table) @@ -1290,6 +1298,8 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e return columns, err } + defer rows.Close() + for rows.Next() { var ( name string diff --git a/orm/orm_raw.go b/orm/orm_raw.go index 7d204876..a713dbac 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -285,6 +285,8 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { return err } + defer rows.Close() + if rows.Next() { if structMode { columns, err := rows.Columns() @@ -408,11 +410,12 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { return 0, err } - nInds := make([]reflect.Value, len(sInds)) - - sInd := sInds[0] + defer rows.Close() var cnt int64 + nInds := make([]reflect.Value, len(sInds)) + sInd := sInds[0] + for rows.Next() { if structMode { @@ -538,6 +541,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { rs = r } + defer rs.Close() + var ( refs []interface{} cnt int64 diff --git a/orm/orm_test.go b/orm/orm_test.go index 410aa484..bd4b6972 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -1392,7 +1392,8 @@ func TestQueryRows(t *testing.T) { var ids []int var usernames []string - num, err = dORM.Raw("SELECT id, user_name FROM user ORDER BY id asc").QueryRows(&ids, &usernames) + query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q) + num, err = dORM.Raw(query).QueryRows(&ids, &usernames) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(len(ids), 3)) From 1e57587fe902edcb9c1554387365043d8c08cea5 Mon Sep 17 00:00:00 2001 From: astaxie Date: Tue, 31 Dec 2013 20:47:48 +0800 Subject: [PATCH 09/46] support Hijacker #428 --- router.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/router.go b/router.go index 5047497f..401fddbb 100644 --- a/router.go +++ b/router.go @@ -1,7 +1,10 @@ package beego import ( + "bufio" + "errors" "fmt" + "net" "net/http" "net/url" "os" @@ -864,3 +867,13 @@ func (w *responseWriter) WriteHeader(code int) { w.started = true w.writer.WriteHeader(code) } + +// hijacker for http +func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := w.writer.(http.Hijacker) + if !ok { + println("supported?") + return nil, nil, errors.New("webserver doesn't support hijacking") + } + return hj.Hijack() +} From 803d91c0773d4358bdbeeaa1f6a5526a6ed5aa2a Mon Sep 17 00:00:00 2001 From: astaxie Date: Tue, 31 Dec 2013 23:43:15 +0800 Subject: [PATCH 10/46] support modules design! // the follow code is write in modules: // GR:=NewGroupRouters() // GR.AddRouter("/login",&UserController,"get:Login") // GR.AddRouter("/logout",&UserController,"get:Logout") // GR.AddRouter("/register",&UserController,"get:Reg") // the follow code is write in app: // import "github.com/beego/modules/auth" // AddRouterGroup("/admin", auth.GR) --- beego.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/beego.go b/beego.go index acafa77b..b3ce46e1 100644 --- a/beego.go +++ b/beego.go @@ -16,10 +16,60 @@ const VERSION = "1.0.1" type hookfunc func() error //hook function to run var hooks []hookfunc //hook function slice to store the hookfunc -func init() { - hooks = make([]hookfunc, 0) - //init mime - AddAPPStartHook(initMime) +type groupRouter struct { + pattern string + controller ControllerInterface + mappingMethods string +} + +// RouterGroups which will store routers +type GroupRouters []groupRouter + +// Get a new GroupRouters +func NewGroupRouters() GroupRouters { + return make([]groupRouter, 0) +} + +// Add Router in the GroupRouters +// it is for plugin or module to register router +func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) { + var newRG groupRouter + if len(mappingMethod) > 0 { + newRG = groupRouter{ + pattern, + c, + mappingMethod[0], + } + } else { + newRG = groupRouter{ + pattern, + c, + "", + } + } + gr = append(gr, newRG) +} + +// AddGroupRouter with the prefix +// it will register the router in BeeApp +// the follow code is write in modules: +// GR:=NewGroupRouters() +// GR.AddRouter("/login",&UserController,"get:Login") +// GR.AddRouter("/logout",&UserController,"get:Logout") +// GR.AddRouter("/register",&UserController,"get:Reg") +// the follow code is write in app: +// import "github.com/beego/modules/auth" +// AddRouterGroup("/admin", auth.GR) +func AddGroupRouter(prefix string, groups GroupRouters) *App { + for _, v := range groups { + if v.mappingMethods != "" { + BeeApp.Router(prefix+v.pattern, v.controller, v.mappingMethods) + } else { + BeeApp.Router(prefix+v.pattern, v.controller) + } + + } + return BeeApp } // Router adds a patterned controller handler to BeeApp. @@ -151,3 +201,9 @@ func Run() { BeeApp.Run() } + +func init() { + hooks = make([]hookfunc, 0) + //init mime + AddAPPStartHook(initMime) +} From d57557dc554d7967f35762830b2d35b0451f94b4 Mon Sep 17 00:00:00 2001 From: astaxie Date: Wed, 1 Jan 2014 17:57:57 +0800 Subject: [PATCH 11/46] add AutoRouterWithPrefix --- app.go | 8 ++++++++ beego.go | 20 +++++++++++++++++++- router.go | 39 ++++++++++++++++++++++++++++++++++++--- router_test.go | 12 ++++++++++++ 4 files changed, 75 insertions(+), 4 deletions(-) diff --git a/app.go b/app.go index da18718e..f252367b 100644 --- a/app.go +++ b/app.go @@ -118,6 +118,14 @@ func (app *App) AutoRouter(c ControllerInterface) *App { return app } +// AutoRouterWithPrefix adds beego-defined controller handler with prefix. +// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func (app *App) AutoRouterWithPrefix(prefix string, c ControllerInterface) *App { + app.Handlers.AddAutoPrefix(prefix, c) + return app +} + // UrlFor creates a url with another registered controller handler with params. // The endpoint is formed as path.controller.name to defined the controller method which will run. // The values need key-pair data to assign into controller method. diff --git a/beego.go b/beego.go index b3ce46e1..fd61a942 100644 --- a/beego.go +++ b/beego.go @@ -50,6 +50,15 @@ func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingM gr = append(gr, newRG) } +func (gr GroupRouters) AddAuto(c ControllerInterface) { + newRG := groupRouter{ + "", + c, + "", + } + gr = append(gr, newRG) +} + // AddGroupRouter with the prefix // it will register the router in BeeApp // the follow code is write in modules: @@ -62,7 +71,9 @@ func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingM // AddRouterGroup("/admin", auth.GR) func AddGroupRouter(prefix string, groups GroupRouters) *App { for _, v := range groups { - if v.mappingMethods != "" { + if v.pattern == "" { + BeeApp.AutoRouterWithPrefix(prefix, v.controller) + } else if v.mappingMethods != "" { BeeApp.Router(prefix+v.pattern, v.controller, v.mappingMethods) } else { BeeApp.Router(prefix+v.pattern, v.controller) @@ -95,6 +106,13 @@ func AutoRouter(c ControllerInterface) *App { return BeeApp } +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to App.AutoRouterWithPrefix. +func AutoPrefix(prefix string, c ControllerInterface) *App { + BeeApp.AutoRouterWithPrefix(prefix, c) + return BeeApp +} + // ErrorHandler registers http.HandlerFunc to each http err code string. // usage: // beego.ErrorHandler("404",NotFound) diff --git a/router.go b/router.go index 401fddbb..f7c92ce7 100644 --- a/router.go +++ b/router.go @@ -33,6 +33,14 @@ const ( var ( // supported http methods. HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"} + // these beego.Controller's methods shouldn't reflect to AutoRouter + exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", + "RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJson", "ServeJsonp", + "ServeXml", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool", + "GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession", + "DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie", + "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", + "GetControllerAndAction"} ) type controllerInfo struct { @@ -221,8 +229,8 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM // Add auto router to ControllerRegistor. // example beego.AddAuto(&MainContorlller{}), // MainController has method List and Page. -// visit the url /main/list to exec List function -// /main/page to exec Page function. +// visit the url /main/list to execute List function +// /main/page to execute Page function. func (p *ControllerRegistor) AddAuto(c ControllerInterface) { p.enableAuto = true reflectVal := reflect.ValueOf(c) @@ -235,7 +243,32 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) { p.autoRouter[firstParam] = make(map[string]reflect.Type) } for i := 0; i < rt.NumMethod(); i++ { - p.autoRouter[firstParam][rt.Method(i).Name] = ct + if !utils.InSlice(rt.Method(i).Name, exceptMethod) { + p.autoRouter[firstParam][rt.Method(i).Name] = ct + } + } +} + +// Add auto router to ControllerRegistor with prefix. +// example beego.AddAutoPrefix("/admin",&MainContorlller{}), +// MainController has method List and Page. +// visit the url /admin/main/list to execute List function +// /admin/main/page to execute Page function. +func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) { + p.enableAuto = true + reflectVal := reflect.ValueOf(c) + rt := reflectVal.Type() + ct := reflect.Indirect(reflectVal).Type() + firstParam := strings.Trim(prefix, "/") + "/" + strings.ToLower(strings.TrimSuffix(ct.Name(), "Controller")) + if _, ok := p.autoRouter[firstParam]; ok { + return + } else { + p.autoRouter[firstParam] = make(map[string]reflect.Type) + } + for i := 0; i < rt.NumMethod(); i++ { + if !utils.InSlice(rt.Method(i).Name, exceptMethod) { + p.autoRouter[firstParam][rt.Method(i).Name] = ct + } } } diff --git a/router_test.go b/router_test.go index a79ab5b4..c1a7f213 100644 --- a/router_test.go +++ b/router_test.go @@ -198,3 +198,15 @@ func TestPrepare(t *testing.T) { t.Errorf(w.Body.String() + "user define func can't run") } } + +func TestAutoPrefix(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/test/list", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegistor() + handler.AddAutoPrefix("/admin", &TestController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("TestAutoPrefix can't run") + } +} From 480aa521e59b29e569b1b0e99d4978f7cc22699f Mon Sep 17 00:00:00 2001 From: astaxie Date: Wed, 1 Jan 2014 20:50:06 +0800 Subject: [PATCH 12/46] fix #430 --- cache/cache_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++- cache/file.go | 10 ++++----- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/cache/cache_test.go b/cache/cache_test.go index cb0fc76c..bc484e0f 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -5,7 +5,7 @@ import ( "time" ) -func Test_cache(t *testing.T) { +func TestCache(t *testing.T) { bm, err := NewCache("memory", `{"interval":20}`) if err != nil { t.Error("init err") @@ -51,3 +51,51 @@ func Test_cache(t *testing.T) { t.Error("delete err") } } + +func TestFileCache(t *testing.T) { + bm, err := NewCache("file", `{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`) + if err != nil { + t.Error("init err") + } + if err = bm.Put("astaxie", 1, 10); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + + if err = bm.Incr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 2 { + t.Error("get err") + } + + if err = bm.Decr("astaxie"); err != nil { + t.Error("Incr Error", err) + } + + if v := bm.Get("astaxie"); v.(int) != 1 { + t.Error("get err") + } + bm.Delete("astaxie") + if bm.IsExist("astaxie") { + t.Error("delete err") + } + //test string + if err = bm.Put("astaxie", "author", 10); err != nil { + t.Error("set Error", err) + } + if !bm.IsExist("astaxie") { + t.Error("check err") + } + + if v := bm.Get("astaxie"); v.(string) != "author" { + t.Error("get err") + } +} diff --git a/cache/file.go b/cache/file.go index 3807fa7c..410da3a0 100644 --- a/cache/file.go +++ b/cache/file.go @@ -61,6 +61,7 @@ func (this *FileCache) StartAndGC(config string) error { var cfg map[string]string json.Unmarshal([]byte(config), &cfg) //fmt.Println(cfg) + //fmt.Println(config) if _, ok := cfg["CachePath"]; !ok { cfg["CachePath"] = FileCachePath } @@ -135,7 +136,7 @@ func (this *FileCache) Get(key string) interface{} { return "" } var to FileCacheItem - Gob_decode([]byte(filedata), &to) + Gob_decode(filedata, &to) if to.Expired < time.Now().Unix() { return "" } @@ -177,7 +178,7 @@ func (this *FileCache) Delete(key string) error { func (this *FileCache) Incr(key string) error { data := this.Get(key) var incr int - fmt.Println(reflect.TypeOf(data).Name()) + //fmt.Println(reflect.TypeOf(data).Name()) if reflect.TypeOf(data).Name() != "int" { incr = 0 } else { @@ -210,8 +211,7 @@ func (this *FileCache) IsExist(key string) bool { // Clean cached files. // not implemented. func (this *FileCache) ClearAll() error { - //this.CachePath .递归删除 - + //this.CachePath return nil } @@ -271,7 +271,7 @@ func Gob_encode(data interface{}) ([]byte, error) { } // Gob decodes file cache item. -func Gob_decode(data []byte, to interface{}) error { +func Gob_decode(data []byte, to *FileCacheItem) error { buf := bytes.NewBuffer(data) dec := gob.NewDecoder(buf) return dec.Decode(&to) From f5cf2876ddebfc387482e076a598ca9f013a50e0 Mon Sep 17 00:00:00 2001 From: Scott Merkling Date: Thu, 2 Jan 2014 09:53:09 -0500 Subject: [PATCH 13/46] Improved the language on the error pages --- middleware/error.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/middleware/error.go b/middleware/error.go index 7f854141..c86be159 100644 --- a/middleware/error.go +++ b/middleware/error.go @@ -185,9 +185,9 @@ func NotFound(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) data["Title"] = "Page Not Found" - data["Content"] = template.HTML("
The Page You have requested flown the coop." + + data["Content"] = template.HTML("
The page you have requested has flown the coop." + "
Perhaps you are here because:" + - "

    " + + "
      " + "
      The page has moved" + "
      The page no longer exists" + "
      You were looking for your puppy and got lost" + @@ -203,11 +203,11 @@ func Unauthorized(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) data["Title"] = "Unauthorized" - data["Content"] = template.HTML("
      The Page You have requested can't authorized." + + data["Content"] = template.HTML("
      The page you have requested can't be authorized." + "
      Perhaps you are here because:" + "

        " + - "
        Check the credentials that you supplied" + - "
        Check the address for errors" + + "
        The credentials you supplied are incorrect" + + "
        There are errors in the website address" + "
      ") data["BeegoVersion"] = VERSION //rw.WriteHeader(http.StatusUnauthorized) @@ -219,7 +219,7 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) data["Title"] = "Forbidden" - data["Content"] = template.HTML("
      The Page You have requested forbidden." + + data["Content"] = template.HTML("
      The page you have requested is forbidden." + "
      Perhaps you are here because:" + "

        " + "
        Your address may be blocked" + @@ -236,7 +236,7 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) data["Title"] = "Service Unavailable" - data["Content"] = template.HTML("
        The Page You have requested unavailable." + + data["Content"] = template.HTML("
        The page you have requested is unavailable." + "
        Perhaps you are here because:" + "

          " + "

          The page is overloaded" + @@ -252,11 +252,10 @@ func InternalServerError(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) data["Title"] = "Internal Server Error" - data["Content"] = template.HTML("
          The Page You have requested has down now." + + data["Content"] = template.HTML("
          The page you have requested is down right now." + "

            " + - "
            simply try again later" + - "
            you should report the fault to the website administrator" + - "
          ") + "
          Please try again later and report the error to the website administrator" + + "
        ") data["BeegoVersion"] = VERSION //rw.WriteHeader(http.StatusInternalServerError) t.Execute(rw, data) From c433b7029f938991327fbf11334e3e09b256cc0d Mon Sep 17 00:00:00 2001 From: Scott Merkling Date: Thu, 2 Jan 2014 09:54:15 -0500 Subject: [PATCH 14/46] added back a
        --- middleware/error.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware/error.go b/middleware/error.go index c86be159..ebfe2288 100644 --- a/middleware/error.go +++ b/middleware/error.go @@ -187,7 +187,7 @@ func NotFound(rw http.ResponseWriter, r *http.Request) { data["Title"] = "Page Not Found" data["Content"] = template.HTML("
        The page you have requested has flown the coop." + "
        Perhaps you are here because:" + - "
          " + + "

            " + "
            The page has moved" + "
            The page no longer exists" + "
            You were looking for your puppy and got lost" + From ef79a2b4846c43a8801b4d8c83d366a6d8e74869 Mon Sep 17 00:00:00 2001 From: slene Date: Sat, 4 Jan 2014 00:04:15 +0800 Subject: [PATCH 15/46] fix #440 --- orm/db.go | 2 ++ orm/db_tables.go | 11 ++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/orm/db.go b/orm/db.go index 782e3bc3..66ae498e 100644 --- a/orm/db.go +++ b/orm/db.go @@ -461,6 +461,8 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) + tables.skipEnd = true + if qs != nil { tables.parseRelated(qs.related, qs.relDepth) } diff --git a/orm/db_tables.go b/orm/db_tables.go index 1ab47cdb..972077c2 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -23,6 +23,7 @@ type dbTables struct { tables []*dbTable mi *modelInfo base dbBaser + skipEnd bool } func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { @@ -221,9 +222,13 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string inner = false } - jt, _ := t.add(names, mmi, fi, inner) - jt.jtl = jtl - jtl = jt + if num == i && t.skipEnd { + } else { + jt, _ := t.add(names, mmi, fi, inner) + jt.jtl = jtl + jtl = jt + } + } if num == i { From 95c65de97cb56f42373e7285424ae4c5ec5d2e32 Mon Sep 17 00:00:00 2001 From: slene Date: Sat, 4 Jan 2014 22:30:17 +0800 Subject: [PATCH 16/46] fix #440 --- orm/db_tables.go | 88 ++++++++++++++++++++++++++++++++---------------- orm/orm_test.go | 26 ++++++++++++++ 2 files changed, 85 insertions(+), 29 deletions(-) diff --git a/orm/db_tables.go b/orm/db_tables.go index 972077c2..5a78cf21 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -112,7 +112,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) { names = append(names, fi.name) mmi = fi.relModelInfo - if fi.null { + if fi.null || t.skipEnd { inner = false } @@ -189,6 +189,8 @@ func (t *dbTables) getJoinSql() (join string) { func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { var ( jtl *dbTable + fi *fieldInfo + fiN *fieldInfo mmi = mi ) @@ -197,9 +199,24 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string inner := true +loopFor: for i, ex := range exprs { - fi, ok := mmi.fields.GetByAny(ex) + var ok, okN bool + + if fiN != nil { + fi = fiN + ok = true + fiN = nil + } + + if i == 0 { + fi, ok = mmi.fields.GetByAny(ex) + } + + // fmt.Println(ex, fi.name, fiN) + + _ = okN if ok { @@ -217,13 +234,20 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string mmi = fi.reverseFieldInfo.mi } + if i < num { + fiN, okN = mmi.fields.GetByAny(exprs[i+1]) + } + if isRel && (fi.mi.isThrough == false || num != i) { - if fi.null { + if fi.null || t.skipEnd { inner = false } - if num == i && t.skipEnd { - } else { + if t.skipEnd && okN || !t.skipEnd { + if t.skipEnd && okN && fiN.pk { + goto loopEnd + } + jt, _ := t.add(names, mmi, fi, inner) jt.jtl = jtl jtl = jt @@ -231,34 +255,40 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string } - if num == i { - if i == 0 || jtl == nil { - index = "T0" - } else { + if num != i { + continue + } + + loopEnd: + + if i == 0 || jtl == nil { + index = "T0" + } else { + index = jtl.index + } + + info = fi + + if jtl == nil { + name = fi.name + } else { + name = jtl.name + ExprSep + fi.name + } + + switch { + case fi.rel: + + case fi.reverse: + switch fi.reverseFieldInfo.fieldType { + case RelOneToOne, RelForeignKey: index = jtl.index - } - - info = fi - - if jtl == nil { - name = fi.name - } else { - name = jtl.name + ExprSep + fi.name - } - - switch { - case fi.rel: - - case fi.reverse: - switch fi.reverseFieldInfo.fieldType { - case RelOneToOne, RelForeignKey: - index = jtl.index - info = fi.reverseFieldInfo.mi.fields.pk - name = info.name - } + info = fi.reverseFieldInfo.mi.fields.pk + name = info.name } } + break loopFor + } else { index = "" name = "" diff --git a/orm/orm_test.go b/orm/orm_test.go index bd4b6972..f5101811 100644 --- a/orm/orm_test.go +++ b/orm/orm_test.go @@ -1561,6 +1561,32 @@ func TestDelete(t *testing.T) { num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 6)) + + qs = dORM.QueryTable("post") + num, err = qs.Filter("Id", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 4)) + + fmt.Println("...") + qs = dORM.QueryTable("comment") + num, err = qs.Filter("Post__User", 3).Delete() + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + + qs = dORM.QueryTable("comment") + num, err = qs.Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) } func TestTransaction(t *testing.T) { From 481448fa90dbceb9554d17b69d293ee0282c6674 Mon Sep 17 00:00:00 2001 From: astaxie Date: Sun, 5 Jan 2014 14:48:36 +0800 Subject: [PATCH 17/46] modify session module change a log --- session/README.md | 35 +++++--- session/sess_cookie.go | 143 +++++++++++++++++++++++++++++ session/sess_file.go | 3 +- session/sess_gob.go | 38 -------- session/sess_mem.go | 10 +-- session/sess_mem_test.go | 35 ++++++++ session/sess_mysql.go | 3 +- session/sess_redis.go | 3 +- session/sess_test.go | 81 +++++++++++++++++ session/sess_utils.go | 188 +++++++++++++++++++++++++++++++++++++++ session/session.go | 154 +++++++++++++++----------------- 11 files changed, 551 insertions(+), 142 deletions(-) create mode 100644 session/sess_cookie.go delete mode 100644 session/sess_gob.go create mode 100644 session/sess_mem_test.go create mode 100644 session/sess_utils.go diff --git a/session/README.md b/session/README.md index 220100ef..2ebf069a 100644 --- a/session/README.md +++ b/session/README.md @@ -28,21 +28,21 @@ Then in you web app init the global session manager * Use **memory** as provider: func init() { - globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"") + globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`) go globalSessions.GC() } * Use **file** as provider, the last param is the path where you want file to be stored: func init() { - globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp") + globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","./tmp"}`) go globalSessions.GC() } * Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: func init() { - globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie") + globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","127.0.0.1:6379,100,astaxie"}`) go globalSessions.GC() } @@ -50,15 +50,24 @@ Then in you web app init the global session manager func init() { globalSessions, _ = session.NewManager( - "mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value") + "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","username:password@protocol(address)/dbname?param=value"}`) go globalSessions.GC() } +* Use **Cookie** as provider: + + func init() { + globalSessions, _ = session.NewManager( + "cookie", `{"cookieName":"gosessionid","enableSetCookie":false,gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`) + go globalSessions.GC() + } + + Finally in the handlerfunc you can use it like this func login(w http.ResponseWriter, r *http.Request) { sess := globalSessions.SessionStart(w, r) - defer sess.SessionRelease() + defer sess.SessionRelease(w) username := sess.Get("username") fmt.Println(username) if r.Method == "GET" { @@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus Writing a provider is easy. You only need to define two struct types (Session and Provider), which satisfy the interface definition. -Maybe you will find the **memory** provider as good example. +Maybe you will find the **memory** provider is a good example. type SessionStore interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease() // release the resource & save data to provider - Flush() error //delete all data + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error //delete all data } type Provider interface { - SessionInit(maxlifetime int64, savePath string) error + SessionInit(gclifetime int64, config string) error SessionRead(sid string) (SessionStore, error) SessionExist(sid string) bool SessionRegenerate(oldsid, sid string) (SessionStore, error) diff --git a/session/sess_cookie.go b/session/sess_cookie.go new file mode 100644 index 00000000..deff70a0 --- /dev/null +++ b/session/sess_cookie.go @@ -0,0 +1,143 @@ +package session + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/json" + "net/http" + "net/url" + "sync" +) + +var cookiepder = &CookieProvider{} + +type CookieSessionStore struct { + sid string + values map[interface{}]interface{} //session data + lock sync.RWMutex +} + +func (st *CookieSessionStore) Set(key, value interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + st.values[key] = value + return nil +} + +func (st *CookieSessionStore) Get(key interface{}) interface{} { + st.lock.RLock() + defer st.lock.RUnlock() + if v, ok := st.values[key]; ok { + return v + } else { + return nil + } + return nil +} + +func (st *CookieSessionStore) Delete(key interface{}) error { + st.lock.Lock() + defer st.lock.Unlock() + delete(st.values, key) + return nil +} + +func (st *CookieSessionStore) Flush() error { + st.lock.Lock() + defer st.lock.Unlock() + st.values = make(map[interface{}]interface{}) + return nil +} + +func (st *CookieSessionStore) SessionID() string { + return st.sid +} + +func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { + str, err := encodeCookie(cookiepder.block, + cookiepder.config.SecurityKey, + cookiepder.config.SecurityName, + st.values) + if err != nil { + return + } + cookie := &http.Cookie{Name: cookiepder.config.CookieName, + Value: url.QueryEscape(str), + Path: "/", + HttpOnly: true, + Secure: cookiepder.config.Secure} + http.SetCookie(w, cookie) + return +} + +type cookieConfig struct { + SecurityKey string `json:"securityKey"` + BlockKey string `json:"blockKey"` + SecurityName string `json:"securityName"` + CookieName string `json:"cookieName"` + Secure bool `json:"secure"` + Maxage int `json:"maxage"` +} + +type CookieProvider struct { + maxlifetime int64 + config *cookieConfig + block cipher.Block +} + +func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { + pder.config = &cookieConfig{} + err := json.Unmarshal([]byte(config), pder.config) + if err != nil { + return err + } + if pder.config.BlockKey == "" { + pder.config.BlockKey = string(generateRandomKey(16)) + } + if pder.config.SecurityName == "" { + pder.config.SecurityName = string(generateRandomKey(20)) + } + pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey)) + if err != nil { + return err + } + return nil +} + +func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) { + kv := make(map[interface{}]interface{}) + kv, _ = decodeCookie(pder.block, + pder.config.SecurityKey, + pder.config.SecurityName, + sid, pder.maxlifetime) + rs := &CookieSessionStore{sid: sid, values: kv} + return rs, nil +} + +func (pder *CookieProvider) SessionExist(sid string) bool { + return true +} + +func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) { + return nil, nil +} + +func (pder *CookieProvider) SessionDestroy(sid string) error { + return nil +} + +func (pder *CookieProvider) SessionGC() { + return +} + +func (pder *CookieProvider) SessionAll() int { + return 0 +} + +func (pder *CookieProvider) SessionUpdate(sid string) error { + return nil +} + +func init() { + Register("cookie", cookiepder) +} diff --git a/session/sess_file.go b/session/sess_file.go index 1db4022e..5d33d0e2 100644 --- a/session/sess_file.go +++ b/session/sess_file.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "io/ioutil" + "net/http" "os" "path" "path/filepath" @@ -60,7 +61,7 @@ func (fs *FileSessionStore) SessionID() string { return fs.sid } -func (fs *FileSessionStore) SessionRelease() { +func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { defer fs.f.Close() b, err := encodeGob(fs.values) if err != nil { diff --git a/session/sess_gob.go b/session/sess_gob.go deleted file mode 100644 index 92313947..00000000 --- a/session/sess_gob.go +++ /dev/null @@ -1,38 +0,0 @@ -package session - -import ( - "bytes" - "encoding/gob" -) - -func init() { - gob.Register([]interface{}{}) - gob.Register(map[int]interface{}{}) - gob.Register(map[string]interface{}{}) - gob.Register(map[interface{}]interface{}{}) - gob.Register(map[string]string{}) - gob.Register(map[int]string{}) - gob.Register(map[int]int{}) - gob.Register(map[int]int64{}) -} - -func encodeGob(obj map[interface{}]interface{}) ([]byte, error) { - buf := bytes.NewBuffer(nil) - enc := gob.NewEncoder(buf) - err := enc.Encode(obj) - if err != nil { - return []byte(""), err - } - return buf.Bytes(), nil -} - -func decodeGob(encoded []byte) (map[interface{}]interface{}, error) { - buf := bytes.NewBuffer(encoded) - dec := gob.NewDecoder(buf) - var out map[interface{}]interface{} - err := dec.Decode(&out) - if err != nil { - return nil, err - } - return out, nil -} diff --git a/session/sess_mem.go b/session/sess_mem.go index 2e615c6f..c74c2602 100644 --- a/session/sess_mem.go +++ b/session/sess_mem.go @@ -2,6 +2,7 @@ package session import ( "container/list" + "net/http" "sync" "time" ) @@ -9,9 +10,9 @@ import ( var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} type MemSessionStore struct { - sid string //session id唯一标示 - timeAccessed time.Time //最后访问时间 - value map[interface{}]interface{} //session里面存储的值 + sid string //session id + timeAccessed time.Time //last access time + value map[interface{}]interface{} //session store lock sync.RWMutex } @@ -51,8 +52,7 @@ func (st *MemSessionStore) SessionID() string { return st.sid } -func (st *MemSessionStore) SessionRelease() { - +func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { } type MemProvider struct { diff --git a/session/sess_mem_test.go b/session/sess_mem_test.go new file mode 100644 index 00000000..df2a9a1e --- /dev/null +++ b/session/sess_mem_test.go @@ -0,0 +1,35 @@ +package session + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestMem(t *testing.T) { + globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`) + go globalSessions.GC() + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess := globalSessions.SessionStart(w, r) + defer sess.SessionRelease(w) + err := sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} diff --git a/session/sess_mysql.go b/session/sess_mysql.go index f1a59d4f..1101e437 100644 --- a/session/sess_mysql.go +++ b/session/sess_mysql.go @@ -9,6 +9,7 @@ package session import ( "database/sql" + "net/http" "sync" "time" @@ -60,7 +61,7 @@ func (st *MysqlSessionStore) SessionID() string { return st.sid } -func (st *MysqlSessionStore) SessionRelease() { +func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) { defer st.c.Close() if len(st.values) > 0 { b, err := encodeGob(st.values) diff --git a/session/sess_redis.go b/session/sess_redis.go index e582c6ed..0f8c0308 100644 --- a/session/sess_redis.go +++ b/session/sess_redis.go @@ -1,6 +1,7 @@ package session import ( + "net/http" "strconv" "strings" "sync" @@ -58,7 +59,7 @@ func (rs *RedisSessionStore) SessionID() string { return rs.sid } -func (rs *RedisSessionStore) SessionRelease() { +func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) { defer rs.c.Close() if len(rs.values) > 0 { b, err := encodeGob(rs.values) diff --git a/session/sess_test.go b/session/sess_test.go index b7d0b38c..d754b526 100644 --- a/session/sess_test.go +++ b/session/sess_test.go @@ -1,6 +1,8 @@ package session import ( + "crypto/aes" + "encoding/json" "testing" ) @@ -26,3 +28,82 @@ func Test_gob(t *testing.T) { t.Error("decode int error") } } + +func TestGenerate(t *testing.T) { + str := generateRandomKey(20) + if len(str) != 20 { + t.Fatal("generate length is not equal to 20") + } +} + +func TestCookieEncodeDecode(t *testing.T) { + hashKey := "testhashKey" + blockkey := generateRandomKey(16) + block, err := aes.NewCipher(blockkey) + if err != nil { + t.Fatal("NewCipher:", err) + } + securityName := string(generateRandomKey(20)) + val := make(map[interface{}]interface{}) + val["name"] = "astaxie" + val["gender"] = "male" + str, err := encodeCookie(block, hashKey, securityName, val) + if err != nil { + t.Fatal("encodeCookie:", err) + } + dst := make(map[interface{}]interface{}) + dst, err = decodeCookie(block, hashKey, securityName, str, 3600) + if err != nil { + t.Fatal("decodeCookie", err) + } + if dst["name"] != "astaxie" { + t.Fatal("dst get map error") + } + if dst["gender"] != "male" { + t.Fatal("dst get map error") + } +} + +func TestParseConfig(t *testing.T) { + s := `{"cookieName":"gosessionid","gclifetime":3600}` + cf := new(managerConfig) + cf.EnableSetCookie = true + err := json.Unmarshal([]byte(s), cf) + if err != nil { + t.Fatal("parse json error,", err) + } + if cf.CookieName != "gosessionid" { + t.Fatal("parseconfig get cookiename error") + } + if cf.Gclifetime != 3600 { + t.Fatal("parseconfig get gclifetime error") + } + + cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + cf2 := new(managerConfig) + cf2.EnableSetCookie = true + err = json.Unmarshal([]byte(cc), cf2) + if err != nil { + t.Fatal("parse json error,", err) + } + if cf2.CookieName != "gosessionid" { + t.Fatal("parseconfig get cookiename error") + } + if cf2.Gclifetime != 3600 { + t.Fatal("parseconfig get gclifetime error") + } + if cf2.EnableSetCookie != false { + t.Fatal("parseconfig get enableSetCookie error") + } + cconfig := new(cookieConfig) + err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig) + if err != nil { + t.Fatal("parse ProviderConfig err,", err) + } + if cconfig.CookieName != "gosessionid" { + t.Fatal("ProviderConfig get cookieName error") + } + if cconfig.SecurityKey != "beegocookiehashkey" { + t.Fatal("ProviderConfig get securityKey error") + } +} diff --git a/session/sess_utils.go b/session/sess_utils.go new file mode 100644 index 00000000..73f96630 --- /dev/null +++ b/session/sess_utils.go @@ -0,0 +1,188 @@ +package session + +import ( + "bytes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha1" + "crypto/subtle" + "encoding/base64" + "encoding/gob" + "errors" + "fmt" + "io" + "strconv" + "time" +) + +func init() { + gob.Register([]interface{}{}) + gob.Register(map[int]interface{}{}) + gob.Register(map[string]interface{}{}) + gob.Register(map[interface{}]interface{}{}) + gob.Register(map[string]string{}) + gob.Register(map[int]string{}) + gob.Register(map[int]int{}) + gob.Register(map[int]int64{}) +} + +func encodeGob(obj map[interface{}]interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(obj) + if err != nil { + return []byte(""), err + } + return buf.Bytes(), nil +} + +func decodeGob(encoded []byte) (map[interface{}]interface{}, error) { + buf := bytes.NewBuffer(encoded) + dec := gob.NewDecoder(buf) + var out map[interface{}]interface{} + err := dec.Decode(&out) + if err != nil { + return nil, err + } + return out, nil +} + +// generateRandomKey creates a random key with the given strength. +func generateRandomKey(strength int) []byte { + k := make([]byte, strength) + if _, err := io.ReadFull(rand.Reader, k); err != nil { + return nil + } + return k +} + +// Encryption ----------------------------------------------------------------- + +// encrypt encrypts a value using the given block in counter mode. +// +// A random initialization vector (http://goo.gl/zF67k) with the length of the +// block size is prepended to the resulting ciphertext. +func encrypt(block cipher.Block, value []byte) ([]byte, error) { + iv := generateRandomKey(block.BlockSize()) + if iv == nil { + return nil, errors.New("encrypt: failed to generate random iv") + } + // Encrypt it. + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(value, value) + // Return iv + ciphertext. + return append(iv, value...), nil +} + +// decrypt decrypts a value using the given block in counter mode. +// +// The value to be decrypted must be prepended by a initialization vector +// (http://goo.gl/zF67k) with the length of the block size. +func decrypt(block cipher.Block, value []byte) ([]byte, error) { + size := block.BlockSize() + if len(value) > size { + // Extract iv. + iv := value[:size] + // Extract ciphertext. + value = value[size:] + // Decrypt it. + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(value, value) + return value, nil + } + return nil, errors.New("decrypt: the value could not be decrypted") +} + +func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) { + var err error + var b []byte + // 1. encodeGob. + if b, err = encodeGob(value); err != nil { + return "", err + } + // 2. Encrypt (optional). + if b, err = encrypt(block, b); err != nil { + return "", err + } + b = encode(b) + // 3. Create MAC for "name|date|value". Extra pipe to be used later. + b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b)) + h := hmac.New(sha1.New, []byte(hashKey)) + h.Write(b) + sig := h.Sum(nil) + // Append mac, remove name. + b = append(b, sig...)[len(name)+1:] + // 4. Encode to base64. + b = encode(b) + // Done. + return string(b), nil +} + +func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) { + // 1. Decode from base64. + b, err := decode([]byte(value)) + if err != nil { + return nil, err + } + // 2. Verify MAC. Value is "date|value|mac". + parts := bytes.SplitN(b, []byte("|"), 3) + if len(parts) != 3 { + return nil, errors.New("Decode: invalid value %v") + } + + b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...) + h := hmac.New(sha1.New, []byte(hashKey)) + h.Write(b) + sig := h.Sum(nil) + if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 { + return nil, errors.New("Decode: the value is not valid") + } + // 3. Verify date ranges. + var t1 int64 + if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil { + return nil, errors.New("Decode: invalid timestamp") + } + t2 := time.Now().UTC().Unix() + if t1 > t2 { + return nil, errors.New("Decode: timestamp is too new") + } + if t1 < t2-gcmaxlifetime { + return nil, errors.New("Decode: expired timestamp") + } + // 4. Decrypt (optional). + b, err = decode(parts[1]) + if err != nil { + return nil, err + } + if b, err = decrypt(block, b); err != nil { + return nil, err + } + // 5. decodeGob. + if dst, err := decodeGob(b); err != nil { + return nil, err + } else { + return dst, nil + } + // Done. + return nil, nil +} + +// Encoding ------------------------------------------------------------------- + +// encode encodes a value using base64. +func encode(value []byte) []byte { + encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value))) + base64.URLEncoding.Encode(encoded, value) + return encoded +} + +// decode decodes a cookie using base64. +func decode(value []byte) ([]byte, error) { + decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value))) + b, err := base64.URLEncoding.Decode(decoded, value) + if err != nil { + return nil, err + } + return decoded[:b], nil +} diff --git a/session/session.go b/session/session.go index 062bbfd6..df348fab 100644 --- a/session/session.go +++ b/session/session.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "crypto/sha1" "encoding/hex" + "encoding/json" "fmt" "io" "net/http" @@ -14,16 +15,16 @@ import ( ) type SessionStore interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease() // release the resource & save data to provider - Flush() error //delete all data + Set(key, value interface{}) error //set session value + Get(key interface{}) interface{} //get session value + Delete(key interface{}) error //delete session value + SessionID() string //back current sessionID + SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data + Flush() error //delete all data } type Provider interface { - SessionInit(maxlifetime int64, savePath string) error + SessionInit(gclifetime int64, config string) error SessionRead(sid string) (SessionStore, error) SessionExist(sid string) bool SessionRegenerate(oldsid, sid string) (SessionStore, error) @@ -47,15 +48,21 @@ func Register(name string, provide Provider) { provides[name] = provide } +type managerConfig struct { + CookieName string `json:"cookieName"` + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + Gclifetime int64 `json:"gclifetime"` + Maxage int `json:"maxage"` + Secure bool `json:"secure"` + SessionIDHashFunc string `json:"sessionIDHashFunc"` + SessionIDHashKey string `json:"sessionIDHashKey"` + CookieLifeTime int64 `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` +} + type Manager struct { - cookieName string //private cookiename - provider Provider - maxlifetime int64 - hashfunc string //support md5 & sha1 - hashkey string - maxage int //cookielifetime - secure bool - options []interface{} + provider Provider + config *managerConfig } //options @@ -63,74 +70,49 @@ type Manager struct { //2. hashfunc default sha1 //3. hashkey default beegosessionkey //4. maxage default is none -func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) { +func NewManager(provideName, config string) (*Manager, error) { provider, ok := provides[provideName] if !ok { return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) } - provider.SessionInit(maxlifetime, savePath) - secure := false - if len(options) > 0 { - secure = options[0].(bool) + cf := new(managerConfig) + cf.EnableSetCookie = true + err := json.Unmarshal([]byte(config), cf) + if err != nil { + return nil, err } - hashfunc := "sha1" - if len(options) > 1 { - hashfunc = options[1].(string) + provider.SessionInit(cf.Gclifetime, cf.ProviderConfig) + + if cf.SessionIDHashFunc == "" { + cf.SessionIDHashFunc = "sha1" } - hashkey := "beegosessionkey" - if len(options) > 2 { - hashkey = options[2].(string) - } - maxage := -1 - if len(options) > 3 { - switch options[3].(type) { - case int: - if options[3].(int) > 0 { - maxage = options[3].(int) - } else if options[3].(int) < 0 { - maxage = 0 - } - case int64: - if options[3].(int64) > 0 { - maxage = int(options[3].(int64)) - } else if options[3].(int64) < 0 { - maxage = 0 - } - case int32: - if options[3].(int32) > 0 { - maxage = int(options[3].(int32)) - } else if options[3].(int32) < 0 { - maxage = 0 - } - } + if cf.SessionIDHashKey == "" { + cf.SessionIDHashKey = string(generateRandomKey(16)) } + return &Manager{ - provider: provider, - cookieName: cookieName, - maxlifetime: maxlifetime, - hashfunc: hashfunc, - hashkey: hashkey, - maxage: maxage, - secure: secure, - options: options, + provider, + cf, }, nil } //get Session func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { - cookie, err := r.Cookie(manager.cookieName) + cookie, err := r.Cookie(manager.config.CookieName) if err != nil || cookie.Value == "" { sid := manager.sessionId(r) session, _ = manager.provider.SessionRead(sid) - cookie = &http.Cookie{Name: manager.cookieName, + cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", HttpOnly: true, - Secure: manager.secure} - if manager.maxage >= 0 { - cookie.MaxAge = manager.maxage + Secure: manager.config.Secure} + if manager.config.Maxage >= 0 { + cookie.MaxAge = manager.config.Maxage + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) } - http.SetCookie(w, cookie) r.AddCookie(cookie) } else { sid, _ := url.QueryUnescape(cookie.Value) @@ -139,15 +121,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se } else { sid = manager.sessionId(r) session, _ = manager.provider.SessionRead(sid) - cookie = &http.Cookie{Name: manager.cookieName, + cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", HttpOnly: true, - Secure: manager.secure} - if manager.maxage >= 0 { - cookie.MaxAge = manager.maxage + Secure: manager.config.Secure} + if manager.config.Maxage >= 0 { + cookie.MaxAge = manager.config.Maxage + } + if manager.config.EnableSetCookie { + http.SetCookie(w, cookie) } - http.SetCookie(w, cookie) r.AddCookie(cookie) } } @@ -156,13 +140,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se //Destroy sessionid func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { - cookie, err := r.Cookie(manager.cookieName) + cookie, err := r.Cookie(manager.config.CookieName) if err != nil || cookie.Value == "" { return } else { manager.provider.SessionDestroy(cookie.Value) expiration := time.Now() - cookie := http.Cookie{Name: manager.cookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1} + cookie := http.Cookie{Name: manager.config.CookieName, + Path: "/", + HttpOnly: true, + Expires: expiration, + MaxAge: -1} http.SetCookie(w, &cookie) } } @@ -174,20 +162,20 @@ func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err erro func (manager *Manager) GC() { manager.provider.SessionGC() - time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) + time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) } func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) { sid := manager.sessionId(r) - cookie, err := r.Cookie(manager.cookieName) + cookie, err := r.Cookie(manager.config.CookieName) if err != nil && cookie.Value == "" { //delete old cookie session, _ = manager.provider.SessionRead(sid) - cookie = &http.Cookie{Name: manager.cookieName, + cookie = &http.Cookie{Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", HttpOnly: true, - Secure: manager.secure, + Secure: manager.config.Secure, } } else { oldsid, _ := url.QueryUnescape(cookie.Value) @@ -196,8 +184,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque cookie.HttpOnly = true cookie.Path = "/" } - if manager.maxage >= 0 { - cookie.MaxAge = manager.maxage + if manager.config.Maxage >= 0 { + cookie.MaxAge = manager.config.Maxage } http.SetCookie(w, cookie) r.AddCookie(cookie) @@ -209,12 +197,12 @@ func (manager *Manager) GetActiveSession() int { } func (manager *Manager) SetHashFunc(hasfunc, hashkey string) { - manager.hashfunc = hasfunc - manager.hashkey = hashkey + manager.config.SessionIDHashFunc = hasfunc + manager.config.SessionIDHashKey = hashkey } func (manager *Manager) SetSecure(secure bool) { - manager.secure = secure + manager.config.Secure = secure } //remote_addr cruunixnano randdata @@ -224,16 +212,16 @@ func (manager *Manager) sessionId(r *http.Request) (sid string) { return "" } sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs) - if manager.hashfunc == "md5" { + if manager.config.SessionIDHashFunc == "md5" { h := md5.New() h.Write([]byte(sig)) sid = hex.EncodeToString(h.Sum(nil)) - } else if manager.hashfunc == "sha1" { - h := hmac.New(sha1.New, []byte(manager.hashkey)) + } else if manager.config.SessionIDHashFunc == "sha1" { + h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey)) fmt.Fprintf(h, "%s", sig) sid = hex.EncodeToString(h.Sum(nil)) } else { - h := hmac.New(sha1.New, []byte(manager.hashkey)) + h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey)) fmt.Fprintf(h, "%s", sig) sid = hex.EncodeToString(h.Sum(nil)) } From 9cbd4757019b8c141e18d6fbd077a8508b2cb301 Mon Sep 17 00:00:00 2001 From: astaxie Date: Sun, 5 Jan 2014 14:59:39 +0800 Subject: [PATCH 18/46] beego support new version session --- beego.go | 19 ++++++++++++------- config.go | 2 ++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/beego.go b/beego.go index fd61a942..586ea5b4 100644 --- a/beego.go +++ b/beego.go @@ -191,14 +191,19 @@ func Run() { } if SessionOn { + sessionConfig := AppConfig.String("sessionConfig") + if sessionConfig == "" { + sessionConfig = `{"cookieName":` + SessionName + `,` + + `"gclifetime":` + SessionGCMaxLifetime + `,` + + `"providerConfig":` + SessionSavePath + `,` + + `"secure":` + HttpTLS + `,` + + `"sessionIDHashFunc":` + SessionHashFunc + `,` + + `"sessionIDHashKey":` + SessionHashKey + `,` + + `"enableSetCookie":` + SessionAutoSetCookie + `,` + + `"cookieLifeTime":` + SessionCookieLifeTime + `,}` + } GlobalSessions, _ = session.NewManager(SessionProvider, - SessionName, - SessionGCMaxLifetime, - SessionSavePath, - HttpTLS, - SessionHashFunc, - SessionHashKey, - SessionCookieLifeTime) + sessionConfig) go GlobalSessions.GC() } diff --git a/config.go b/config.go index 9baf9aa5..39a893c0 100644 --- a/config.go +++ b/config.go @@ -40,6 +40,7 @@ var ( SessionHashFunc string // session hash generation func. SessionHashKey string // session hash salt string. SessionCookieLifeTime int // the life time of session id in cookie. + SessionAutoSetCookie bool // auto setcookie UseFcgi bool MaxMemory int64 EnableGzip bool // flag of enable gzip @@ -96,6 +97,7 @@ func init() { SessionHashFunc = "sha1" SessionHashKey = "beegoserversessionkey" SessionCookieLifeTime = 0 //set cookie default is the brower life + SessionAutoSetCookie = true UseFcgi = false From 31bdb793cf9bf1698ecaf485561979acc66abad7 Mon Sep 17 00:00:00 2001 From: astaxie Date: Sun, 5 Jan 2014 15:21:50 +0800 Subject: [PATCH 19/46] make fix --- beego.go | 9 +++++---- router.go | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/beego.go b/beego.go index 586ea5b4..d672e834 100644 --- a/beego.go +++ b/beego.go @@ -4,6 +4,7 @@ import ( "net/http" "path" "path/filepath" + "strconv" "strings" "github.com/astaxie/beego/middleware" @@ -194,13 +195,13 @@ func Run() { sessionConfig := AppConfig.String("sessionConfig") if sessionConfig == "" { sessionConfig = `{"cookieName":` + SessionName + `,` + - `"gclifetime":` + SessionGCMaxLifetime + `,` + + `"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` + `"providerConfig":` + SessionSavePath + `,` + - `"secure":` + HttpTLS + `,` + + `"secure":` + strconv.FormatBool(HttpTLS) + `,` + `"sessionIDHashFunc":` + SessionHashFunc + `,` + `"sessionIDHashKey":` + SessionHashKey + `,` + - `"enableSetCookie":` + SessionAutoSetCookie + `,` + - `"cookieLifeTime":` + SessionCookieLifeTime + `,}` + `"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` + + `"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `,}` } GlobalSessions, _ = session.NewManager(SessionProvider, sessionConfig) diff --git a/router.go b/router.go index f7c92ce7..fce3885b 100644 --- a/router.go +++ b/router.go @@ -529,7 +529,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) // session init if SessionOn { context.Input.CruSession = GlobalSessions.SessionStart(w, r) - defer context.Input.CruSession.SessionRelease() + defer context.Input.CruSession.SessionRelease(w) } if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) { From 338124e3fb3abe31b3d1b2c5d71cd48e4864c08c Mon Sep 17 00:00:00 2001 From: astaxie Date: Sun, 5 Jan 2014 15:43:48 +0800 Subject: [PATCH 20/46] fix #443 --- middleware/error.go | 1 + 1 file changed, 1 insertion(+) diff --git a/middleware/error.go b/middleware/error.go index ebfe2288..35d9eb59 100644 --- a/middleware/error.go +++ b/middleware/error.go @@ -71,6 +71,7 @@ func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack str data["Stack"] = Stack data["BeegoVersion"] = VERSION data["GoVersion"] = runtime.Version() + rw.WriteHeader(500) t.Execute(rw, data) } From 6f3a759ba5e2af5666ea30080d6c9396a7c258fb Mon Sep 17 00:00:00 2001 From: astaxie Date: Sun, 5 Jan 2014 23:16:47 +0800 Subject: [PATCH 21/46] gmfim add lock. fix #445 --- memzipfile.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/memzipfile.go b/memzipfile.go index 43a82a30..8d3edd9c 100644 --- a/memzipfile.go +++ b/memzipfile.go @@ -5,16 +5,17 @@ import ( "compress/flate" "compress/gzip" "errors" - //"fmt" "io" "io/ioutil" "net/http" "os" "strings" + "sync" "time" ) var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo) +var lock sync.RWMutex // OpenMemZipFile returns MemFile object with a compressed static file. // it's used for serve static file if gzip enable. @@ -32,12 +33,12 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) { modtime := osfileinfo.ModTime() fileSize := osfileinfo.Size() - + lock.RLock() cfi, ok := gmfim[zip+":"+path] + lock.RUnlock() if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize { - //fmt.Printf("read %s file %s from cache\n", zip, path) + } else { - //fmt.Printf("NOT read %s file %s from cache\n", zip, path) var content []byte if zip == "gzip" { //将文件内容压缩到zipbuf中 @@ -81,8 +82,9 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) { } cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize} + lock.Lock() + defer lock.Unlock() gmfim[zip+":"+path] = cfi - //fmt.Printf("%s file %s to %d, cache it\n", zip, path, len(content)) } return &MemFile{fi: cfi, offset: 0}, nil } From b766f65c268c3d55aeea4a3380b651660c0029e4 Mon Sep 17 00:00:00 2001 From: slene Date: Mon, 6 Jan 2014 11:07:03 +0800 Subject: [PATCH 22/46] #436 support insert multi --- orm/db.go | 119 ++++++++++++++++++++++++++++++++++++++------ orm/db_tables.go | 2 - orm/orm.go | 81 +++++++++++++++++++++--------- orm/orm_querym2m.go | 12 ++--- orm/types.go | 4 +- 5 files changed, 167 insertions(+), 51 deletions(-) diff --git a/orm/db.go b/orm/db.go index 66ae498e..c6e92ec9 100644 --- a/orm/db.go +++ b/orm/db.go @@ -51,7 +51,13 @@ type dbBase struct { var _ dbBaser = new(dbBase) -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { +func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { + var columns []string + + if names != nil { + columns = *names + } + for _, column := range cols { var fi *fieldInfo if fi, _ = mi.fields.GetByAny(column); fi != nil { @@ -64,11 +70,20 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, } value, err := d.collectFieldValue(mi, fi, ind, insert, tz) if err != nil { - return nil, nil, err + return nil, err } - columns = append(columns, column) + + if names != nil { + columns = append(columns, column) + } + values = append(values, value) } + + if names != nil { + *names = columns + } + return } @@ -166,7 +181,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, } func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - _, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) + values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { return 0, err } @@ -192,7 +207,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo // if specify cols length > 0, then use it for where condition. if len(cols) > 0 { var err error - whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz) + whereCols = make([]string, 0, len(cols)) + args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) if err != nil { return err } @@ -202,7 +218,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo if ok == false { return ErrMissPK } - whereCols = append(whereCols, pkColumn) + whereCols = []string{pkColumn} args = append(args, pkValue) } @@ -244,15 +260,72 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo } func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) + names := make([]string, 0, len(mi.fields.dbcols)-1) + values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) if err != nil { return 0, err } - return d.InsertValue(q, mi, names, values) + return d.InsertValue(q, mi, false, names, values) } -func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) { +func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { + var ( + cnt int64 + nums int + values []interface{} + names []string + ) + + // typ := reflect.Indirect(mi.addrField).Type() + + length := sind.Len() + + for i := 1; i <= length; i++ { + + ind := reflect.Indirect(sind.Index(i - 1)) + + // Is this needed ? + // if !ind.Type().AssignableTo(typ) { + // return cnt, ErrArgs + // } + + if i == 1 { + vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) + if err != nil { + return cnt, err + } + values = make([]interface{}, bulk*len(vus)) + nums += copy(values, vus) + + } else { + + vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) + if err != nil { + return cnt, err + } + + if len(vus) != len(names) { + return cnt, ErrArgs + } + + nums += copy(values[nums:], vus) + } + + if i > 1 && i%bulk == 0 || length == i { + num, err := d.InsertValue(q, mi, true, names, values[:nums]) + if err != nil { + return cnt, err + } + cnt += num + nums = 0 + } + } + + return cnt, nil +} + +func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -264,21 +337,30 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values qmarks := strings.Join(marks, ", ") columns := strings.Join(names, sep) + multi := len(values) / len(names) + + if isMulti { + qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + } + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) d.ins.ReplaceMarks(&query) - if d.ins.HasReturningID(mi, &query) { - row := q.QueryRow(query, values...) - var id int64 - err := row.Scan(&id) - return id, err - } else { + if isMulti || !d.ins.HasReturningID(mi, &query) { if res, err := q.Exec(query, values...); err == nil { + if isMulti { + return res.RowsAffected() + } return res.LastInsertId() } else { return 0, err } + } else { + row := q.QueryRow(query, values...) + var id int64 + err := row.Scan(&id) + return id, err } } @@ -288,12 +370,17 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, ErrMissPK } + var setNames []string + // if specify cols length is zero, then commit all columns. if len(cols) == 0 { cols = mi.fields.dbcols + setNames = make([]string, 0, len(mi.fields.dbcols)-1) + } else { + setNames = make([]string, 0, len(cols)) } - setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz) + setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) if err != nil { return 0, err } diff --git a/orm/db_tables.go b/orm/db_tables.go index 5a78cf21..f5cacf38 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -214,8 +214,6 @@ loopFor: fi, ok = mmi.fields.GetByAny(ex) } - // fmt.Println(ex, fi.name, fiN) - _ = okN if ok { diff --git a/orm/orm.go b/orm/orm.go index 0069aa1d..9e3c3565 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -25,6 +25,7 @@ var ( ErrMultiRows = errors.New(" return multi rows") ErrNoRows = errors.New(" no row found") ErrStmtClosed = errors.New(" stmt already closed") + ErrArgs = errors.New(" args error may be empty") ErrNotImplement = errors.New("have not implement") ) @@ -39,11 +40,11 @@ type orm struct { var _ Ormer = new(orm) -func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { +func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) typ := ind.Type() - if val.Kind() != reflect.Ptr { + if needPtr && val.Kind() != reflect.Ptr { panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) } name := getFullName(typ) @@ -62,7 +63,7 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { } func (o *orm) Read(md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) if err != nil { return err @@ -71,25 +72,63 @@ func (o *orm) Read(md interface{}, cols ...string) error { } func (o *orm) Insert(md interface{}) (int64, error) { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) if err != nil { return id, err } - if id > 0 { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { - ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id)) - } else { - ind.Field(mi.fields.pk.fieldIndex).SetInt(id) - } - } - } + + o.setPk(mi, ind, id) + return id, nil } +func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { + if mi.fields.pk.auto { + if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { + ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id)) + } else { + ind.Field(mi.fields.pk.fieldIndex).SetInt(id) + } + } +} + +func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { + var cnt int64 + + sind := reflect.Indirect(reflect.ValueOf(mds)) + + switch sind.Kind() { + case reflect.Array, reflect.Slice: + if sind.Len() == 0 { + return cnt, ErrArgs + } + default: + return cnt, ErrArgs + } + + if bulk <= 1 { + for i := 0; i < sind.Len(); i++ { + ind := sind.Index(i) + mi, _ := o.getMiInd(ind.Interface(), false) + id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + if err != nil { + return cnt, err + } + + o.setPk(mi, ind, id) + + cnt += 1 + } + } else { + mi, _ := o.getMiInd(sind.Index(0).Interface(), false) + return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) + } + return cnt, nil +} + func (o *orm) Update(md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) if err != nil { return num, err @@ -98,25 +137,19 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) { } func (o *orm) Delete(md interface{}) (int64, error) { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) if err != nil { return num, err } if num > 0 { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { - ind.Field(mi.fields.pk.fieldIndex).SetUint(0) - } else { - ind.Field(mi.fields.pk.fieldIndex).SetInt(0) - } - } + o.setPk(mi, ind, 0) } return num, nil } func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) switch { @@ -197,7 +230,7 @@ func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { } func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { - mi, ind := o.getMiInd(md) + mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) _, _, exist := getExistPk(mi, ind) diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go index 876fc37e..6f0544d0 100644 --- a/orm/orm_querym2m.go +++ b/orm/orm_querym2m.go @@ -44,7 +44,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { names := []string{mfi.column, rfi.column} - var nums int64 + values := make([]interface{}, 0, len(models)*2) + for _, md := range models { ind := reflect.Indirect(reflect.ValueOf(md)) @@ -59,16 +60,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } } - values := []interface{}{v1, v2} - _, err := dbase.InsertValue(orm.db, mi, names, values) - if err != nil { - return nums, err - } + values = append(values, v1, v2) - nums += 1 } - return nums, nil + return dbase.InsertValue(orm.db, mi, true, names, values) } func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { diff --git a/orm/types.go b/orm/types.go index 4749124c..a6487fc0 100644 --- a/orm/types.go +++ b/orm/types.go @@ -21,6 +21,7 @@ type Fielder interface { type Ormer interface { Read(interface{}, ...string) error Insert(interface{}) (int64, error) + InsertMulti(int, interface{}) (int64, error) Update(interface{}, ...string) (int64, error) Delete(interface{}) (int64, error) LoadRelated(interface{}, string, ...interface{}) (int64, error) @@ -109,7 +110,8 @@ type txEnder interface { type dbBaser interface { Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error) + InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) From aa2fef0d36151297fe0f70480353d8fd45b9700f Mon Sep 17 00:00:00 2001 From: astaxie Date: Wed, 8 Jan 2014 20:54:20 +0800 Subject: [PATCH 23/46] update sessionRelease 1. mysql fix last access time not update 2. mysql & redid Release when data is empty 3. add maxlifetime distinct Gclifetime --- README.md | 7 ------- session/sess_mysql.go | 15 ++++++++------- session/sess_redis.go | 12 +++++------- session/session.go | 7 ++++++- 4 files changed, 19 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 8d010eb9..00ad468d 100644 --- a/README.md +++ b/README.md @@ -34,10 +34,3 @@ More info [beego.me](http://beego.me) beego is licensed under the Apache Licence, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0.html). - - -## Use case - -- Displaying API documentation: [gowalker](https://github.com/Unknwon/gowalker) -- seocms: [seocms](https://github.com/chinakr/seocms) -- CMS: [toropress](https://github.com/insionng/toropress) diff --git a/session/sess_mysql.go b/session/sess_mysql.go index 1101e437..3b0c6f3f 100644 --- a/session/sess_mysql.go +++ b/session/sess_mysql.go @@ -63,13 +63,13 @@ func (st *MysqlSessionStore) SessionID() string { func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) { defer st.c.Close() - if len(st.values) > 0 { - b, err := encodeGob(st.values) - if err != nil { - return - } - st.c.Exec("UPDATE session set `session_data`= ? where session_key=?", b, st.sid) + b, err := encodeGob(st.values) + if err != nil { + return } + st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?", + b, time.Now().Unix(), st.sid) + } type MysqlProvider struct { @@ -97,7 +97,8 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) { var sessiondata []byte err := row.Scan(&sessiondata) if err == sql.ErrNoRows { - c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix()) + c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", + sid, "", time.Now().Unix()) } var kv map[interface{}]interface{} if len(sessiondata) == 0 { diff --git a/session/sess_redis.go b/session/sess_redis.go index 0f8c0308..51685844 100644 --- a/session/sess_redis.go +++ b/session/sess_redis.go @@ -61,14 +61,12 @@ func (rs *RedisSessionStore) SessionID() string { func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) { defer rs.c.Close() - if len(rs.values) > 0 { - b, err := encodeGob(rs.values) - if err != nil { - return - } - rs.c.Do("SET", rs.sid, string(b)) - rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime) + b, err := encodeGob(rs.values) + if err != nil { + return } + rs.c.Do("SET", rs.sid, string(b)) + rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime) } type RedisProvider struct { diff --git a/session/session.go b/session/session.go index df348fab..0447dac7 100644 --- a/session/session.go +++ b/session/session.go @@ -52,6 +52,7 @@ type managerConfig struct { CookieName string `json:"cookieName"` EnableSetCookie bool `json:"enableSetCookie,omitempty"` Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` Maxage int `json:"maxage"` Secure bool `json:"secure"` SessionIDHashFunc string `json:"sessionIDHashFunc"` @@ -81,7 +82,11 @@ func NewManager(provideName, config string) (*Manager, error) { if err != nil { return nil, err } - provider.SessionInit(cf.Gclifetime, cf.ProviderConfig) + if cf.Maxlifetime == 0 { + cf.Maxlifetime = cf.Gclifetime + } + + provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) if cf.SessionIDHashFunc == "" { cf.SessionIDHashFunc = "sha1" From d06c04277f449ecd39a7083449a29fc637c079b2 Mon Sep 17 00:00:00 2001 From: astaxie Date: Wed, 8 Jan 2014 22:31:26 +0800 Subject: [PATCH 24/46] support send mail --- utils/mail.go | 299 +++++++++++++++++++++++++++++++++++++++++++++ utils/mail_test.go | 24 ++++ 2 files changed, 323 insertions(+) create mode 100644 utils/mail.go create mode 100644 utils/mail_test.go diff --git a/utils/mail.go b/utils/mail.go new file mode 100644 index 00000000..5096111b --- /dev/null +++ b/utils/mail.go @@ -0,0 +1,299 @@ +package utils + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net/mail" + "net/smtp" + "net/textproto" + "os" + "path" + "path/filepath" + "strconv" + "strings" +) + +const ( + maxLineLength = 76 +) + +// Email is the type used for email messages +type Email struct { + Auth smtp.Auth + Identity string `json:"identity"` + Username string `json:"username"` + Password string `json:"password"` + Host string `json:"host"` + Port int `json:"port"` + From string `json:"from"` + To []string + Bcc []string + Cc []string + Subject string + Text string // Plaintext message (optional) + HTML string // Html message (optional) + Headers textproto.MIMEHeader + Attachments []*Attachment + ReadReceipt []string +} + +// Attachment is a struct representing an email attachment. +// Based on the mime/multipart.FileHeader struct, Attachment contains the name, MIMEHeader, and content of the attachment in question +type Attachment struct { + Filename string + Header textproto.MIMEHeader + Content []byte +} + +func NewEMail(config string) *Email { + e := new(Email) + e.Headers = textproto.MIMEHeader{} + err := json.Unmarshal([]byte(config), e) + if err != nil { + return nil + } + if e.From == "" { + e.From = e.Username + } + return e +} + +// make all send information to byte +func (e *Email) Bytes() ([]byte, error) { + buff := &bytes.Buffer{} + w := multipart.NewWriter(buff) + // Set the appropriate headers (overwriting any conflicts) + // Leave out Bcc (only included in envelope headers) + e.Headers.Set("To", strings.Join(e.To, ",")) + if e.Cc != nil { + e.Headers.Set("Cc", strings.Join(e.Cc, ",")) + } + e.Headers.Set("From", e.From) + e.Headers.Set("Subject", e.Subject) + if len(e.ReadReceipt) != 0 { + e.Headers.Set("Disposition-Notification-To", strings.Join(e.ReadReceipt, ",")) + } + e.Headers.Set("MIME-Version", "1.0") + e.Headers.Set("Content-Type", fmt.Sprintf("multipart/mixed;\r\n boundary=%s\r\n", w.Boundary())) + + // Write the envelope headers (including any custom headers) + if err := headerToBytes(buff, e.Headers); err != nil { + return nil, fmt.Errorf("Failed to render message headers: %s", err) + } + // Start the multipart/mixed part + fmt.Fprintf(buff, "--%s\r\n", w.Boundary()) + header := textproto.MIMEHeader{} + // Check to see if there is a Text or HTML field + if e.Text != "" || e.HTML != "" { + subWriter := multipart.NewWriter(buff) + // Create the multipart alternative part + header.Set("Content-Type", fmt.Sprintf("multipart/alternative;\r\n boundary=%s\r\n", subWriter.Boundary())) + // Write the header + if err := headerToBytes(buff, header); err != nil { + return nil, fmt.Errorf("Failed to render multipart message headers: %s", err) + } + // Create the body sections + if e.Text != "" { + header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8")) + header.Set("Content-Transfer-Encoding", "quoted-printable") + if _, err := subWriter.CreatePart(header); err != nil { + return nil, err + } + // Write the text + if err := quotePrintEncode(buff, e.Text); err != nil { + return nil, err + } + } + if e.HTML != "" { + header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8")) + header.Set("Content-Transfer-Encoding", "quoted-printable") + if _, err := subWriter.CreatePart(header); err != nil { + return nil, err + } + // Write the text + if err := quotePrintEncode(buff, e.HTML); err != nil { + return nil, err + } + } + if err := subWriter.Close(); err != nil { + return nil, err + } + } + // Create attachment part, if necessary + for _, a := range e.Attachments { + ap, err := w.CreatePart(a.Header) + if err != nil { + return nil, err + } + // Write the base64Wrapped content to the part + base64Wrap(ap, a.Content) + } + if err := w.Close(); err != nil { + return nil, err + } + return buff.Bytes(), nil +} + +// Attach file to the send mail +func (e *Email) AttachFile(filename string) (a *Attachment, err error) { + f, err := os.Open(filename) + if err != nil { + return + } + ct := mime.TypeByExtension(filepath.Ext(filename)) + basename := path.Base(filename) + return e.Attach(f, basename, ct) +} + +// Attach is used to attach content from an io.Reader to the email. +// Required parameters include an io.Reader, the desired filename for the attachment, and the Content-Type +func (e *Email) Attach(r io.Reader, filename string, c string) (a *Attachment, err error) { + var buffer bytes.Buffer + if _, err = io.Copy(&buffer, r); err != nil { + return + } + at := &Attachment{ + Filename: filename, + Header: textproto.MIMEHeader{}, + Content: buffer.Bytes(), + } + // Get the Content-Type to be used in the MIMEHeader + if c != "" { + at.Header.Set("Content-Type", c) + } else { + // If the Content-Type is blank, set the Content-Type to "application/octet-stream" + at.Header.Set("Content-Type", "application/octet-stream") + } + at.Header.Set("Content-Disposition", fmt.Sprintf("attachment;\r\n filename=\"%s\"", filename)) + at.Header.Set("Content-Transfer-Encoding", "base64") + e.Attachments = append(e.Attachments, at) + return at, nil +} + +func (e *Email) Send() error { + if e.Auth == nil { + e.Auth = smtp.PlainAuth(e.Identity, e.Username, e.Password, e.Host) + } + // Merge the To, Cc, and Bcc fields + to := make([]string, 0, len(e.To)+len(e.Cc)+len(e.Bcc)) + to = append(append(append(to, e.To...), e.Cc...), e.Bcc...) + // Check to make sure there is at least one recipient and one "From" address + if e.From == "" || len(to) == 0 { + return errors.New("Must specify at least one From address and one To address") + } + from, err := mail.ParseAddress(e.From) + if err != nil { + return err + } + raw, err := e.Bytes() + if err != nil { + return err + } + return smtp.SendMail(e.Host+":"+strconv.Itoa(e.Port), e.Auth, from.Address, to, raw) +} + +// quotePrintEncode writes the quoted-printable text to the IO Writer (according to RFC 2045) +func quotePrintEncode(w io.Writer, s string) error { + var buf [3]byte + mc := 0 + for i := 0; i < len(s); i++ { + c := s[i] + // We're assuming Unix style text formats as input (LF line break), and + // quoted-printble uses CRLF line breaks. (Literal CRs will become + // "=0D", but probably shouldn't be there to begin with!) + if c == '\n' { + io.WriteString(w, "\r\n") + mc = 0 + continue + } + + var nextOut []byte + if isPrintable(c) { + nextOut = append(buf[:0], c) + } else { + nextOut = buf[:] + qpEscape(nextOut, c) + } + + // Add a soft line break if the next (encoded) byte would push this line + // to or past the limit. + if mc+len(nextOut) >= maxLineLength { + if _, err := io.WriteString(w, "=\r\n"); err != nil { + return err + } + mc = 0 + } + + if _, err := w.Write(nextOut); err != nil { + return err + } + mc += len(nextOut) + } + // No trailing end-of-line?? Soft line break, then. TODO: is this sane? + if mc > 0 { + io.WriteString(w, "=\r\n") + } + return nil +} + +// isPrintable returns true if the rune given is "printable" according to RFC 2045, false otherwise +func isPrintable(c byte) bool { + return (c >= '!' && c <= '<') || (c >= '>' && c <= '~') || (c == ' ' || c == '\n' || c == '\t') +} + +// qpEscape is a helper function for quotePrintEncode which escapes a +// non-printable byte. Expects len(dest) == 3. +func qpEscape(dest []byte, c byte) { + const nums = "0123456789ABCDEF" + dest[0] = '=' + dest[1] = nums[(c&0xf0)>>4] + dest[2] = nums[(c & 0xf)] +} + +// headerToBytes enumerates the key and values in the header, and writes the results to the IO Writer +func headerToBytes(w io.Writer, t textproto.MIMEHeader) error { + for k, v := range t { + // Write the header key + _, err := fmt.Fprintf(w, "%s:", k) + if err != nil { + return err + } + // Write each value in the header + for _, c := range v { + _, err := fmt.Fprintf(w, " %s\r\n", c) + if err != nil { + return err + } + } + } + return nil +} + +// base64Wrap encodeds the attachment content, and wraps it according to RFC 2045 standards (every 76 chars) +// The output is then written to the specified io.Writer +func base64Wrap(w io.Writer, b []byte) { + // 57 raw bytes per 76-byte base64 line. + const maxRaw = 57 + // Buffer for each line, including trailing CRLF. + var buffer [maxLineLength + len("\r\n")]byte + copy(buffer[maxLineLength:], "\r\n") + // Process raw chunks until there's no longer enough to fill a line. + for len(b) >= maxRaw { + base64.StdEncoding.Encode(buffer[:], b[:maxRaw]) + w.Write(buffer[:]) + b = b[maxRaw:] + } + // Handle the last chunk of bytes. + if len(b) > 0 { + out := buffer[:base64.StdEncoding.EncodedLen(len(b))] + base64.StdEncoding.Encode(out, b) + out = append(out, "\r\n"...) + w.Write(out) + } +} diff --git a/utils/mail_test.go b/utils/mail_test.go new file mode 100644 index 00000000..a75c8681 --- /dev/null +++ b/utils/mail_test.go @@ -0,0 +1,24 @@ +package utils + +import "testing" + +func TestMail(t *testing.T) { + config := `{"username":"astaxie@gmail.com","password":"astaxie","host":"smtp.gmail.com","port":587}` + mail := NewEMail(config) + if mail.Username != "astaxie@gmail.com" { + t.Fatal("email parse get username error") + } + if mail.Password != "astaxie" { + t.Fatal("email parse get password error") + } + if mail.Host != "smtp.gmail.com" { + t.Fatal("email parse get host error") + } + if mail.Port != 587 { + t.Fatal("email parse get port error") + } + mail.To = []string{"xiemengjun@gmail.com"} + mail.From = "astaxie@gmail.com" + mail.Subject = "hi, just from beego!" + mail.Send() +} From d7f2c738c8d46b4461053b5d632e92c0b7e81473 Mon Sep 17 00:00:00 2001 From: astaxie Date: Wed, 8 Jan 2014 22:35:42 +0800 Subject: [PATCH 25/46] add attach file --- utils/mail_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/utils/mail_test.go b/utils/mail_test.go index a75c8681..c0535ed5 100644 --- a/utils/mail_test.go +++ b/utils/mail_test.go @@ -20,5 +20,8 @@ func TestMail(t *testing.T) { mail.To = []string{"xiemengjun@gmail.com"} mail.From = "astaxie@gmail.com" mail.Subject = "hi, just from beego!" + mail.Text = "Text Body is, of course, supported!" + mail.HTML = "

            Fancy Html is supported, too!

            " + mail.AttachFile("/Users/astaxie/github/beego/beego.go") mail.Send() } From e34f8c4634cdbbb0ab47c645b7e456f3902600d9 Mon Sep 17 00:00:00 2001 From: astaxie Date: Wed, 8 Jan 2014 23:24:31 +0800 Subject: [PATCH 26/46] add cookie test --- session/sess_cookie.go | 8 +++++--- session/sess_cookie_test.go | 38 +++++++++++++++++++++++++++++++++++++ session/session.go | 7 ++++--- 3 files changed, 47 insertions(+), 6 deletions(-) create mode 100644 session/sess_cookie_test.go diff --git a/session/sess_cookie.go b/session/sess_cookie.go index deff70a0..7962be18 100644 --- a/session/sess_cookie.go +++ b/session/sess_cookie.go @@ -105,12 +105,14 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error } func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) { - kv := make(map[interface{}]interface{}) - kv, _ = decodeCookie(pder.block, + maps, _ := decodeCookie(pder.block, pder.config.SecurityKey, pder.config.SecurityName, sid, pder.maxlifetime) - rs := &CookieSessionStore{sid: sid, values: kv} + if maps == nil { + maps = make(map[interface{}]interface{}) + } + rs := &CookieSessionStore{sid: sid, values: maps} return rs, nil } diff --git a/session/sess_cookie_test.go b/session/sess_cookie_test.go new file mode 100644 index 00000000..154c15a2 --- /dev/null +++ b/session/sess_cookie_test.go @@ -0,0 +1,38 @@ +package session + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestCookie(t *testing.T) { + config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` + globalSessions, err := NewManager("cookie", config) + if err != nil { + t.Fatal("init cookie session err", err) + } + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + sess := globalSessions.SessionStart(w, r) + err = sess.Set("username", "astaxie") + if err != nil { + t.Fatal("set error,", err) + } + if username := sess.Get("username"); username != "astaxie" { + t.Fatal("get username error") + } + sess.SessionRelease(w) + if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + t.Fatal("setcookie error") + } else { + parts := strings.Split(strings.TrimSpace(cookiestr), ";") + for k, v := range parts { + nameval := strings.Split(v, "=") + if k == 0 && nameval[0] != "gosessionid" { + t.Fatal("error") + } + } + } +} diff --git a/session/session.go b/session/session.go index 0447dac7..f41ba85b 100644 --- a/session/session.go +++ b/session/session.go @@ -85,9 +85,10 @@ func NewManager(provideName, config string) (*Manager, error) { if cf.Maxlifetime == 0 { cf.Maxlifetime = cf.Gclifetime } - - provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) - + err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) + if err != nil { + return nil, err + } if cf.SessionIDHashFunc == "" { cf.SessionIDHashFunc = "sha1" } From a369b15ef29ed5f7e2418f4d058c78f72d085b1b Mon Sep 17 00:00:00 2001 From: Pengfei Xue Date: Thu, 9 Jan 2014 18:49:18 +0800 Subject: [PATCH 27/46] reset cache connection to nil, if err isio.EOF * this will support auto-connection --- cache/redis.go | 56 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 6 deletions(-) diff --git a/cache/redis.go b/cache/redis.go index b923a6df..0fac0deb 100644 --- a/cache/redis.go +++ b/cache/redis.go @@ -3,6 +3,7 @@ package cache import ( "encoding/json" "errors" + "io" "github.com/beego/redigo/redis" ) @@ -33,10 +34,18 @@ func (rc *RedisCache) Get(key string) interface{} { return nil } } + v, err := rc.c.Do("HGET", rc.key, key) + // write to closed socket, reset rc.c to nil + if err == io.EOF { + rc.c = nil + return nil + } + if err != nil { return nil } + return v } @@ -50,7 +59,14 @@ func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error { return err } } + _, err := rc.c.Do("HSET", rc.key, key, val) + // write to closed socket, reset rc.c to nil + if err == io.EOF { + rc.c = nil + return err + } + return err } @@ -63,7 +79,14 @@ func (rc *RedisCache) Delete(key string) error { return err } } + _, err := rc.c.Do("HDEL", rc.key, key) + // write to closed socket, reset rc.c to nil + if err == io.EOF { + rc.c = nil + return err + } + return err } @@ -76,10 +99,18 @@ func (rc *RedisCache) IsExist(key string) bool { return false } } + v, err := redis.Bool(rc.c.Do("HEXISTS", rc.key, key)) + // write to closed socket, reset rc.c to nil + if err == io.EOF { + rc.c = nil + return false + } + if err != nil { return false } + return v } @@ -92,11 +123,14 @@ func (rc *RedisCache) Incr(key string) error { return err } } + _, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1)) - if err != nil { - return err + // write to closed socket + if err == io.EOF { + rc.c = nil } - return nil + + return err } // decrease counter in redis. @@ -108,11 +142,15 @@ func (rc *RedisCache) Decr(key string) error { return err } } + _, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1)) - if err != nil { - return err + + // write to closed socket + if err == io.EOF { + rc.c = nil } - return nil + + return err } // clean all cache in redis. delete this redis collection. @@ -124,7 +162,13 @@ func (rc *RedisCache) ClearAll() error { return err } } + _, err := rc.c.Do("DEL", rc.key) + // write to closed socket + if err == io.EOF { + rc.c = nil + } + return err } From 0b42e5573bed39e3a5c0368632e61f054223b48a Mon Sep 17 00:00:00 2001 From: Pengfei Xue Date: Thu, 9 Jan 2014 18:50:30 +0800 Subject: [PATCH 28/46] align memcache operations with redis --- cache/memcache.go | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/cache/memcache.go b/cache/memcache.go index 15d3649c..365c5de7 100644 --- a/cache/memcache.go +++ b/cache/memcache.go @@ -21,7 +21,11 @@ func NewMemCache() *MemcacheCache { // get value from memcache. func (rc *MemcacheCache) Get(key string) interface{} { if rc.c == nil { - rc.c = rc.connectInit() + var err error + rc.c, err = rc.connectInit() + if err != nil { + return err + } } v, err := rc.c.Get(key) if err != nil { @@ -39,7 +43,11 @@ func (rc *MemcacheCache) Get(key string) interface{} { // put value to memcache. only support string. func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error { if rc.c == nil { - rc.c = rc.connectInit() + var err error + rc.c, err = rc.connectInit() + if err != nil { + return err + } } v, ok := val.(string) if !ok { @@ -55,7 +63,11 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error { // delete value in memcache. func (rc *MemcacheCache) Delete(key string) error { if rc.c == nil { - rc.c = rc.connectInit() + var err error + rc.c, err = rc.connectInit() + if err != nil { + return err + } } _, err := rc.c.Delete(key) return err @@ -76,7 +88,11 @@ func (rc *MemcacheCache) Decr(key string) error { // check value exists in memcache. func (rc *MemcacheCache) IsExist(key string) bool { if rc.c == nil { - rc.c = rc.connectInit() + var err error + rc.c, err = rc.connectInit() + if err != nil { + return false + } } v, err := rc.c.Get(key) if err != nil { @@ -93,7 +109,11 @@ func (rc *MemcacheCache) IsExist(key string) bool { // clear all cached in memcache. func (rc *MemcacheCache) ClearAll() error { if rc.c == nil { - rc.c = rc.connectInit() + var err error + rc.c, err = rc.connectInit() + if err != nil { + return err + } } err := rc.c.FlushAll() return err @@ -109,20 +129,21 @@ func (rc *MemcacheCache) StartAndGC(config string) error { return errors.New("config has no conn key") } rc.conninfo = cf["conn"] - rc.c = rc.connectInit() - if rc.c == nil { + var err error + rc.c, err = rc.connectInit() + if err != nil { return errors.New("dial tcp conn error") } return nil } // connect to memcache and keep the connection. -func (rc *MemcacheCache) connectInit() *memcache.Connection { +func (rc *MemcacheCache) connectInit() (*memcache.Connection, error) { c, err := memcache.Connect(rc.conninfo) if err != nil { - return nil + return nil, err } - return c + return c, nil } func init() { From 844412c302478acd9945f94eb00a6704ee6ff906 Mon Sep 17 00:00:00 2001 From: astaxie Date: Thu, 9 Jan 2014 21:37:50 +0800 Subject: [PATCH 29/46] fix #453 --- beego.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/beego.go b/beego.go index d672e834..57d4822f 100644 --- a/beego.go +++ b/beego.go @@ -194,17 +194,20 @@ func Run() { if SessionOn { sessionConfig := AppConfig.String("sessionConfig") if sessionConfig == "" { - sessionConfig = `{"cookieName":` + SessionName + `,` + + sessionConfig = `{"cookieName":"` + SessionName + `",` + `"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` + - `"providerConfig":` + SessionSavePath + `,` + + `"providerConfig":"` + SessionSavePath + `",` + `"secure":` + strconv.FormatBool(HttpTLS) + `,` + - `"sessionIDHashFunc":` + SessionHashFunc + `,` + - `"sessionIDHashKey":` + SessionHashKey + `,` + + `"sessionIDHashFunc":"` + SessionHashFunc + `",` + + `"sessionIDHashKey":"` + SessionHashKey + `",` + `"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` + - `"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `,}` + `"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}` } - GlobalSessions, _ = session.NewManager(SessionProvider, + GlobalSessions, err := session.NewManager(SessionProvider, sessionConfig) + if err != nil { + panic(err) + } go GlobalSessions.GC() } From afadb3f6df6c44e4e2f0a62420af917f362eacc8 Mon Sep 17 00:00:00 2001 From: sol lu Date: Fri, 10 Jan 2014 13:31:08 +0800 Subject: [PATCH 30/46] Update beego.go --- beego.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/beego.go b/beego.go index 57d4822f..e4a96c17 100644 --- a/beego.go +++ b/beego.go @@ -192,6 +192,7 @@ func Run() { } if SessionOn { + var err error sessionConfig := AppConfig.String("sessionConfig") if sessionConfig == "" { sessionConfig = `{"cookieName":"` + SessionName + `",` + @@ -203,7 +204,7 @@ func Run() { `"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` + `"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}` } - GlobalSessions, err := session.NewManager(SessionProvider, + GlobalSessions, err = session.NewManager(SessionProvider, sessionConfig) if err != nil { panic(err) From 8d79f8387bd5326948910a287d1119725e8cd126 Mon Sep 17 00:00:00 2001 From: slene Date: Fri, 10 Jan 2014 16:50:03 +0800 Subject: [PATCH 31/46] #441 fix detect timezone in mysql --- orm/db_alias.go | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/orm/db_alias.go b/orm/db_alias.go index ca90bf3a..24924312 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -123,21 +123,18 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) { switch al.Driver { case DR_MySQL: - row := al.DB.QueryRow("SELECT @@session.time_zone") + row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") var tz string row.Scan(&tz) - if tz == "SYSTEM" { - tz = "" - row = al.DB.QueryRow("SELECT @@system_time_zone") - row.Scan(&tz) - t, err := time.Parse("MST", tz) - if err == nil { - al.TZ = t.Location() + if len(tz) >= 8 { + if tz[0] != '-' { + tz = "+" + tz } - } else { - t, err := time.Parse("-07:00", tz) + t, err := time.Parse("-07:00:00", tz) if err == nil { al.TZ = t.Location() + } else { + DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) } } @@ -163,6 +160,8 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) { loc, err := time.LoadLocation(tz) if err == nil { al.TZ = loc + } else { + DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) } } From b64e70e7dfaf1cfa2e538aeb397dd3df2d875ef9 Mon Sep 17 00:00:00 2001 From: Pengfei Xue Date: Fri, 10 Jan 2014 18:31:15 +0800 Subject: [PATCH 32/46] use connection pool for redis cache --- cache/redis.go | 160 ++++++++++++++----------------------------------- 1 file changed, 44 insertions(+), 116 deletions(-) diff --git a/cache/redis.go b/cache/redis.go index 0fac0deb..90047b6f 100644 --- a/cache/redis.go +++ b/cache/redis.go @@ -3,7 +3,8 @@ package cache import ( "encoding/json" "errors" - "io" + "sync" + "time" "github.com/beego/redigo/redis" ) @@ -15,9 +16,10 @@ var ( // Redis cache adapter. type RedisCache struct { - c redis.Conn + p *redis.Pool // redis connection pool conninfo string key string + mu sync.Mutex } // create new redis cache with default collection name. @@ -25,23 +27,17 @@ func NewRedisCache() *RedisCache { return &RedisCache{key: DefaultKey} } +// actually do the redis cmds +func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) { + c := rc.p.Get() + defer c.Close() + + return c.Do(commandName, args...) +} + // Get cache from redis. func (rc *RedisCache) Get(key string) interface{} { - if rc.c == nil { - var err error - rc.c, err = rc.connectInit() - if err != nil { - return nil - } - } - - v, err := rc.c.Do("HGET", rc.key, key) - // write to closed socket, reset rc.c to nil - if err == io.EOF { - rc.c = nil - return nil - } - + v, err := rc.do("HGET", rc.key, key) if err != nil { return nil } @@ -52,61 +48,19 @@ func (rc *RedisCache) Get(key string) interface{} { // put cache to redis. // timeout is ignored. func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error { - if rc.c == nil { - var err error - rc.c, err = rc.connectInit() - if err != nil { - return err - } - } - - _, err := rc.c.Do("HSET", rc.key, key, val) - // write to closed socket, reset rc.c to nil - if err == io.EOF { - rc.c = nil - return err - } - + _, err := rc.do("HSET", rc.key, key, val) return err } // delete cache in redis. func (rc *RedisCache) Delete(key string) error { - if rc.c == nil { - var err error - rc.c, err = rc.connectInit() - if err != nil { - return err - } - } - - _, err := rc.c.Do("HDEL", rc.key, key) - // write to closed socket, reset rc.c to nil - if err == io.EOF { - rc.c = nil - return err - } - + _, err := rc.do("HDEL", rc.key, key) return err } // check cache exist in redis. func (rc *RedisCache) IsExist(key string) bool { - if rc.c == nil { - var err error - rc.c, err = rc.connectInit() - if err != nil { - return false - } - } - - v, err := redis.Bool(rc.c.Do("HEXISTS", rc.key, key)) - // write to closed socket, reset rc.c to nil - if err == io.EOF { - rc.c = nil - return false - } - + v, err := redis.Bool(rc.do("HEXISTS", rc.key, key)) if err != nil { return false } @@ -116,59 +70,19 @@ func (rc *RedisCache) IsExist(key string) bool { // increase counter in redis. func (rc *RedisCache) Incr(key string) error { - if rc.c == nil { - var err error - rc.c, err = rc.connectInit() - if err != nil { - return err - } - } - - _, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1)) - // write to closed socket - if err == io.EOF { - rc.c = nil - } - + _, err := redis.Bool(rc.do("HINCRBY", rc.key, key, 1)) return err } // decrease counter in redis. func (rc *RedisCache) Decr(key string) error { - if rc.c == nil { - var err error - rc.c, err = rc.connectInit() - if err != nil { - return err - } - } - - _, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1)) - - // write to closed socket - if err == io.EOF { - rc.c = nil - } - + _, err := redis.Bool(rc.do("HINCRBY", rc.key, key, -1)) return err } // clean all cache in redis. delete this redis collection. func (rc *RedisCache) ClearAll() error { - if rc.c == nil { - var err error - rc.c, err = rc.connectInit() - if err != nil { - return err - } - } - - _, err := rc.c.Do("DEL", rc.key) - // write to closed socket - if err == io.EOF { - rc.c = nil - } - + _, err := rc.do("DEL", rc.key) return err } @@ -179,32 +93,46 @@ func (rc *RedisCache) ClearAll() error { func (rc *RedisCache) StartAndGC(config string) error { var cf map[string]string json.Unmarshal([]byte(config), &cf) + if _, ok := cf["key"]; !ok { cf["key"] = DefaultKey } + if _, ok := cf["conn"]; !ok { return errors.New("config has no conn key") } + rc.key = cf["key"] rc.conninfo = cf["conn"] - var err error - rc.c, err = rc.connectInit() - if err != nil { + rc.connectInit() + + c := rc.p.Get() + defer c.Close() + if err := c.Err(); err != nil { return err } - if rc.c == nil { - return errors.New("dial tcp conn error") - } + return nil } // connect to redis. -func (rc *RedisCache) connectInit() (redis.Conn, error) { - c, err := redis.Dial("tcp", rc.conninfo) - if err != nil { - return nil, err +func (rc *RedisCache) connectInit() { + rc.mu.Lock() + + // initialize a new pool + rc.p = &redis.Pool{ + MaxIdle: 3, + IdleTimeout: 180 * time.Second, + Dial: func() (redis.Conn, error) { + c, err := redis.Dial("tcp", rc.conninfo) + if err != nil { + return nil, err + } + return c, nil + }, } - return c, nil + + rc.mu.Unlock() } func init() { From cb55009c8be136d3aa427f53f9408b45669736d4 Mon Sep 17 00:00:00 2001 From: Pengfei Xue Date: Fri, 10 Jan 2014 20:31:43 +0800 Subject: [PATCH 33/46] remove mutex --- cache/redis.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/cache/redis.go b/cache/redis.go index 90047b6f..ba1d4d49 100644 --- a/cache/redis.go +++ b/cache/redis.go @@ -3,7 +3,6 @@ package cache import ( "encoding/json" "errors" - "sync" "time" "github.com/beego/redigo/redis" @@ -19,7 +18,6 @@ type RedisCache struct { p *redis.Pool // redis connection pool conninfo string key string - mu sync.Mutex } // create new redis cache with default collection name. @@ -117,8 +115,6 @@ func (rc *RedisCache) StartAndGC(config string) error { // connect to redis. func (rc *RedisCache) connectInit() { - rc.mu.Lock() - // initialize a new pool rc.p = &redis.Pool{ MaxIdle: 3, @@ -131,8 +127,6 @@ func (rc *RedisCache) connectInit() { return c, nil }, } - - rc.mu.Unlock() } func init() { From 3b99f37aa121fedaa87a5b19ea6a6d976ec3c75b Mon Sep 17 00:00:00 2001 From: slene Date: Sat, 11 Jan 2014 14:28:11 +0800 Subject: [PATCH 34/46] add a empty fake config Initialize AppConfig to avoid nil pointer runtime error. --- beego.go | 2 +- config.go | 1 + config/fake.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 config/fake.go diff --git a/beego.go b/beego.go index e4a96c17..c22719c6 100644 --- a/beego.go +++ b/beego.go @@ -192,7 +192,7 @@ func Run() { } if SessionOn { - var err error + var err error sessionConfig := AppConfig.String("sessionConfig") if sessionConfig == "" { sessionConfig = `{"cookieName":"` + SessionName + `",` + diff --git a/config.go b/config.go index 39a893c0..7c21d696 100644 --- a/config.go +++ b/config.go @@ -141,6 +141,7 @@ func init() { func ParseConfig() (err error) { AppConfig, err = config.NewConfig("ini", AppConfigPath) if err != nil { + AppConfig = config.NewFakeConfig() return err } else { HttpAddr = AppConfig.String("HttpAddr") diff --git a/config/fake.go b/config/fake.go new file mode 100644 index 00000000..05279932 --- /dev/null +++ b/config/fake.go @@ -0,0 +1,58 @@ +package config + +import ( + "errors" + "strconv" + "strings" +) + +type fakeConfigContainer struct { + data map[string]string +} + +func (c *fakeConfigContainer) getData(key string) string { + key = strings.ToLower(key) + return c.data[key] +} + +func (c *fakeConfigContainer) Set(key, val string) error { + key = strings.ToLower(key) + c.data[key] = val + return nil +} + +func (c *fakeConfigContainer) String(key string) string { + return c.getData(key) +} + +func (c *fakeConfigContainer) Int(key string) (int, error) { + return strconv.Atoi(c.getData(key)) +} + +func (c *fakeConfigContainer) Int64(key string) (int64, error) { + return strconv.ParseInt(c.getData(key), 10, 64) +} + +func (c *fakeConfigContainer) Bool(key string) (bool, error) { + return strconv.ParseBool(c.getData(key)) +} + +func (c *fakeConfigContainer) Float(key string) (float64, error) { + return strconv.ParseFloat(c.getData(key), 64) +} + +func (c *fakeConfigContainer) DIY(key string) (interface{}, error) { + key = strings.ToLower(key) + if v, ok := c.data[key]; ok { + return v, nil + } + return nil, errors.New("key not find") +} + +var _ ConfigContainer = new(fakeConfigContainer) + +func NewFakeConfig() ConfigContainer { + return &fakeConfigContainer{ + data: make(map[string]string), + } +} From 6e9ba0ea7f014642f62a915d28e6868cd41f1474 Mon Sep 17 00:00:00 2001 From: slene Date: Sat, 11 Jan 2014 17:01:33 +0800 Subject: [PATCH 35/46] fix SessionRegenerateID should release old SessionStore and release new SessionStore in router.go --- controller.go | 3 ++- router.go | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/controller.go b/controller.go index 04352706..342bf06b 100644 --- a/controller.go +++ b/controller.go @@ -165,7 +165,7 @@ func (c *Controller) RenderBytes() ([]byte, error) { if c.LayoutSections != nil { for sectionName, sectionTpl := range c.LayoutSections { - if (sectionTpl == "") { + if sectionTpl == "" { c.Data[sectionName] = "" continue } @@ -391,6 +391,7 @@ func (c *Controller) DelSession(name interface{}) { // SessionRegenerateID regenerates session id for this session. // the session data have no changes. func (c *Controller) SessionRegenerateID() { + c.CruSession.SessionRelease(c.Ctx.ResponseWriter) c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request) c.Ctx.Input.CruSession = c.CruSession } diff --git a/router.go b/router.go index fce3885b..bb083768 100644 --- a/router.go +++ b/router.go @@ -529,7 +529,9 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) // session init if SessionOn { context.Input.CruSession = GlobalSessions.SessionStart(w, r) - defer context.Input.CruSession.SessionRelease(w) + defer func() { + context.Input.CruSession.SessionRelease(w) + }() } if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) { From dc767b65dfe4007bb4fa702f2c3977e6cfbe96f7 Mon Sep 17 00:00:00 2001 From: Norman <53300940@qq.com> Date: Tue, 14 Jan 2014 19:54:32 +0800 Subject: [PATCH 36/46] Update SessionExist to close the db connection close the mysql connection --- session/sess_mysql.go | 1 + 1 file changed, 1 insertion(+) diff --git a/session/sess_mysql.go b/session/sess_mysql.go index 3b0c6f3f..7bad9e4a 100644 --- a/session/sess_mysql.go +++ b/session/sess_mysql.go @@ -115,6 +115,7 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) { func (mp *MysqlProvider) SessionExist(sid string) bool { c := mp.connectInit() + defer c.Close() row := c.QueryRow("select session_data from session where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) From b016102d34f94b91c4a6f27603afda1ae74cf2a7 Mon Sep 17 00:00:00 2001 From: astaxie Date: Wed, 15 Jan 2014 09:40:33 +0800 Subject: [PATCH 37/46] add coding --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 00ad468d..776e5c2d 100644 --- a/README.md +++ b/README.md @@ -34,3 +34,6 @@ More info [beego.me](http://beego.me) beego is licensed under the Apache Licence, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0.html). + +[![Clone in Koding](http://learn.koding.com/btn/clone_d.png)][koding] +[koding]: https://koding.com/Teamwork?import=https://github.com/astaxie/beego/archive/master.zip&c=git1 \ No newline at end of file From fee3c2b8f9a4ca782346044ab4331728997f5421 Mon Sep 17 00:00:00 2001 From: astaxie Date: Wed, 15 Jan 2014 17:19:03 +0800 Subject: [PATCH 38/46] add Strings interface can return []string sep by ; Example: peers = one;Two;Three --- config/config.go | 5 +++-- config/fake.go | 4 ++++ config/ini.go | 5 +++++ config/ini_test.go | 8 ++++++++ config/json.go | 5 +++++ config/xml.go | 6 ++++++ config/yaml.go | 6 ++++++ 7 files changed, 37 insertions(+), 2 deletions(-) diff --git a/config/config.go b/config/config.go index 5fb0dd81..5e4c2e9c 100644 --- a/config/config.go +++ b/config/config.go @@ -6,8 +6,9 @@ import ( // ConfigContainer defines how to get and set value from configuration raw data. type ConfigContainer interface { - Set(key, val string) error // support section::key type in given key when using ini type. - String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + Set(key, val string) error // support section::key type in given key when using ini type. + String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + Strings(key string) []string //get string slice Int(key string) (int, error) Int64(key string) (int64, error) Bool(key string) (bool, error) diff --git a/config/fake.go b/config/fake.go index 05279932..26a9f430 100644 --- a/config/fake.go +++ b/config/fake.go @@ -25,6 +25,10 @@ func (c *fakeConfigContainer) String(key string) string { return c.getData(key) } +func (c *fakeConfigContainer) Strings(key string) []string { + return strings.Split(c.getData(key), ";") +} + func (c *fakeConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.getData(key)) } diff --git a/config/ini.go b/config/ini.go index 22c23f40..75e6486c 100644 --- a/config/ini.go +++ b/config/ini.go @@ -146,6 +146,11 @@ func (c *IniConfigContainer) String(key string) string { return c.getdata(key) } +// Strings returns the []string value for a given key. +func (c *IniConfigContainer) Strings(key string) []string { + return strings.Split(c.String(key), ";") +} + // WriteValue writes a new value for key. // if write to one section, the key need be "section::key". // if the section is not existed, it panics. diff --git a/config/ini_test.go b/config/ini_test.go index cf87e77c..08a69e50 100644 --- a/config/ini_test.go +++ b/config/ini_test.go @@ -19,6 +19,7 @@ copyrequestbody = true key1="asta" key2 = "xie" CaseInsensitive = true +peers = one;two;three ` func TestIni(t *testing.T) { @@ -78,4 +79,11 @@ func TestIni(t *testing.T) { if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true { t.Fatal("get demo.caseinsensitive error") } + + if data := iniconf.Strings("demo::peers"); len(data) != 3 { + t.Fatal("get strings error", data) + } else if data[0] != "one" { + t.Fatal("get first params error not equat to one") + } + } diff --git a/config/json.go b/config/json.go index 883e0674..24874e8a 100644 --- a/config/json.go +++ b/config/json.go @@ -116,6 +116,11 @@ func (c *JsonConfigContainer) String(key string) string { return "" } +// Strings returns the []string value for a given key. +func (c *JsonConfigContainer) Strings(key string) []string { + return strings.Split(c.String(key), ";") +} + // WriteValue writes a new value for key. func (c *JsonConfigContainer) Set(key, val string) error { c.Lock() diff --git a/config/xml.go b/config/xml.go index 35f19336..7943d8fe 100644 --- a/config/xml.go +++ b/config/xml.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "os" "strconv" + "strings" "sync" "github.com/beego/x2j" @@ -72,6 +73,11 @@ func (c *XMLConfigContainer) String(key string) string { return "" } +// Strings returns the []string value for a given key. +func (c *XMLConfigContainer) Strings(key string) []string { + return strings.Split(c.String(key), ";") +} + // WriteValue writes a new value for key. func (c *XMLConfigContainer) Set(key, val string) error { c.Lock() diff --git a/config/yaml.go b/config/yaml.go index 394cb3b2..bd10b846 100644 --- a/config/yaml.go +++ b/config/yaml.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "log" "os" + "strings" "sync" "github.com/beego/goyaml2" @@ -117,6 +118,11 @@ func (c *YAMLConfigContainer) String(key string) string { return "" } +// Strings returns the []string value for a given key. +func (c *YAMLConfigContainer) Strings(key string) []string { + return strings.Split(c.String(key), ";") +} + // WriteValue writes a new value for key. func (c *YAMLConfigContainer) Set(key, val string) error { c.Lock() From f419c12427da0234236b89f74ad049ff85dab6bb Mon Sep 17 00:00:00 2001 From: slene Date: Thu, 16 Jan 2014 20:53:35 +0800 Subject: [PATCH 39/46] add captcha util --- controller.go | 15 +- utils/captcha/captcha.go | 184 ++++++++++++++ utils/captcha/image.go | 484 ++++++++++++++++++++++++++++++++++++ utils/captcha/image_test.go | 38 +++ utils/rand.go | 20 ++ 5 files changed, 728 insertions(+), 13 deletions(-) create mode 100644 utils/captcha/captcha.go create mode 100644 utils/captcha/image.go create mode 100644 utils/captcha/image_test.go create mode 100644 utils/rand.go diff --git a/controller.go b/controller.go index 342bf06b..034e4cb3 100644 --- a/controller.go +++ b/controller.go @@ -3,7 +3,6 @@ package beego import ( "bytes" "crypto/hmac" - "crypto/rand" "crypto/sha1" "encoding/base64" "errors" @@ -22,6 +21,7 @@ import ( "github.com/astaxie/beego/context" "github.com/astaxie/beego/session" + "github.com/astaxie/beego/utils" ) var ( @@ -455,7 +455,7 @@ func (c *Controller) XsrfToken() string { } else { expire = int64(XSRFExpire) } - token = getRandomString(15) + token = string(utils.RandomCreateBytes(15)) c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire) } c._xsrf_token = token @@ -492,14 +492,3 @@ func (c *Controller) XsrfFormHtml() string { func (c *Controller) GetControllerAndAction() (controllerName, actionName string) { return c.controllerName, c.actionName } - -// getRandomString returns random string. -func getRandomString(n int) string { - const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - var bytes = make([]byte, n) - rand.Read(bytes) - for i, b := range bytes { - bytes[i] = alphanum[b%byte(len(alphanum))] - } - return string(bytes) -} diff --git a/utils/captcha/captcha.go b/utils/captcha/captcha.go new file mode 100644 index 00000000..174197a0 --- /dev/null +++ b/utils/captcha/captcha.go @@ -0,0 +1,184 @@ +package captcha + +// modifiy and integrated to Beego in one file from https://github.com/dchest/captcha + +import ( + "fmt" + "html/template" + "net/http" + "path" + "strings" + + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/context" + "github.com/astaxie/beego/utils" +) + +var ( + defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +) + +const ( + challengeNums = 6 + expiration = 600 + fieldIdName = "captcha_id" + fieldCaptchaName = "captcha" + cachePrefix = "captcha_" + urlPrefix = "/captcha/" +) + +type Captcha struct { + store cache.Cache + urlPrefix string + FieldIdName string + FieldCaptchaName string + StdWidth int + StdHeight int + ChallengeNums int + Expiration int64 + CachePrefix string +} + +func (c *Captcha) key(id string) string { + return c.CachePrefix + id +} + +func (c *Captcha) genRandChars() []byte { + return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...) +} + +func (c *Captcha) Handler(ctx *context.Context) { + var chars []byte + + id := path.Base(ctx.Request.RequestURI) + if i := strings.Index(id, "."); i != -1 { + id = id[:i] + } + + key := c.key(id) + + if v, ok := c.store.Get(key).([]byte); ok { + chars = v + } else { + ctx.Output.SetStatus(404) + ctx.WriteString("captcha not found") + return + } + + // reload captcha + if len(ctx.Input.Query("reload")) > 0 { + chars = c.genRandChars() + if err := c.store.Put(key, chars, c.Expiration); err != nil { + ctx.Output.SetStatus(500) + ctx.WriteString("captcha reload error") + beego.Error("Reload Create Captcha Error:", err) + return + } + } + + img := NewImage(chars, c.StdWidth, c.StdHeight) + if _, err := img.WriteTo(ctx.ResponseWriter); err != nil { + beego.Error("Write Captcha Image Error:", err) + } +} + +func (c *Captcha) CreateCaptchaHtml() template.HTML { + value, err := c.CreateCaptcha() + if err != nil { + beego.Error("Create Captcha Error:", err) + return "" + } + + // create html + return template.HTML(fmt.Sprintf(``+ + ``+ + ``+ + ``, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value)) +} + +func (c *Captcha) CreateCaptcha() (string, error) { + // generate captcha id + id := string(utils.RandomCreateBytes(15)) + + // get the captcha chars + chars := c.genRandChars() + + // save to store + if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil { + return "", err + } + + return id, nil +} + +func (c *Captcha) VerifyReq(req *http.Request) bool { + req.ParseForm() + return c.Verify(req.Form.Get(c.FieldIdName), req.Form.Get(c.FieldCaptchaName)) +} + +func (c *Captcha) Verify(id string, challenge string) (success bool) { + if len(challenge) == 0 || len(id) == 0 { + return + } + + var chars []byte + + key := c.key(id) + + if v, ok := c.store.Get(key).([]byte); ok && len(v) == len(challenge) { + chars = v + } else { + return + } + + defer func() { + // finally remove it + c.store.Delete(key) + }() + + // verify challenge + for i, c := range chars { + if c != challenge[i]-48 { + return + } + } + + return true +} + +func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { + cpt := &Captcha{} + cpt.store = store + cpt.FieldIdName = fieldIdName + cpt.FieldCaptchaName = fieldCaptchaName + cpt.ChallengeNums = challengeNums + cpt.Expiration = expiration + cpt.CachePrefix = cachePrefix + cpt.StdWidth = stdWidth + cpt.StdHeight = stdHeight + + if len(urlPrefix) == 0 { + urlPrefix = urlPrefix + } + + if urlPrefix[len(urlPrefix)-1] != '/' { + urlPrefix += "/" + } + + cpt.urlPrefix = urlPrefix + + return cpt +} + +func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { + cpt := NewCaptcha(urlPrefix, store) + + // create filter for serve captcha image + beego.AddFilter(urlPrefix+":", "BeforeRouter", cpt.Handler) + + // add to template func map + beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHtml) + + return cpt +} diff --git a/utils/captcha/image.go b/utils/captcha/image.go new file mode 100644 index 00000000..d9ae13a6 --- /dev/null +++ b/utils/captcha/image.go @@ -0,0 +1,484 @@ +package captcha + +import ( + "bytes" + "image" + "image/color" + "image/png" + "io" + "math" + "math/rand" + "time" +) + +const ( + fontWidth = 11 + fontHeight = 18 + blackChar = 1 + + // Standard width and height of a captcha image. + stdWidth = 240 + stdHeight = 80 + // Maximum absolute skew factor of a single digit. + maxSkew = 0.7 + // Number of background circles. + circleCount = 20 +) + +var font = [][]byte{ + { // 0 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 1 + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }, + { // 2 + 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + }, + { // 3 + 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 4 + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + }, + { // 5 + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 6 + 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 7 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, + }, + { // 8 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + }, + { // 9 + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, + 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + }, +} + +type Image struct { + *image.Paletted + numWidth int + numHeight int + dotSize int +} + +func getrand() *rand.Rand { + return rand.New(rand.NewSource(time.Now().UnixNano())) +} + +func randIntn(max int) int { + return getrand().Intn(max) +} + +func randInt(min, max int) int { + return getrand().Intn(max-min) + min +} + +func randFloat(min, max float64) float64 { + return (max-min)*getrand().Float64() + min +} + +func randomPalette() color.Palette { + p := make([]color.Color, circleCount+1) + // Transparent color. + p[0] = color.RGBA{0xFF, 0xFF, 0xFF, 0x00} + // Primary color. + prim := color.RGBA{ + uint8(randIntn(129)), + uint8(randIntn(129)), + uint8(randIntn(129)), + 0xFF, + } + p[1] = prim + // Circle colors. + for i := 2; i <= circleCount; i++ { + p[i] = randomBrightness(prim, 255) + } + return p +} + +// NewImage returns a new captcha image of the given width and height with the +// given digits, where each digit must be in range 0-9. +func NewImage(digits []byte, width, height int) *Image { + m := new(Image) + m.Paletted = image.NewPaletted(image.Rect(0, 0, width, height), randomPalette()) + m.calculateSizes(width, height, len(digits)) + // Randomly position captcha inside the image. + maxx := width - (m.numWidth+m.dotSize)*len(digits) - m.dotSize + maxy := height - m.numHeight - m.dotSize*2 + var border int + if width > height { + border = height / 5 + } else { + border = width / 5 + } + x := randInt(border, maxx-border) + y := randInt(border, maxy-border) + // Draw digits. + for _, n := range digits { + m.drawDigit(font[n], x, y) + x += m.numWidth + m.dotSize + } + // Draw strike-through line. + m.strikeThrough() + // Apply wave distortion. + m.distort(randFloat(5, 10), randFloat(100, 200)) + // Fill image with random circles. + m.fillWithCircles(circleCount, m.dotSize) + return m +} + +// encodedPNG encodes an image to PNG and returns +// the result as a byte slice. +func (m *Image) encodedPNG() []byte { + var buf bytes.Buffer + if err := png.Encode(&buf, m.Paletted); err != nil { + panic(err.Error()) + } + return buf.Bytes() +} + +// WriteTo writes captcha image in PNG format into the given writer. +func (m *Image) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(m.encodedPNG()) + return int64(n), err +} + +func (m *Image) calculateSizes(width, height, ncount int) { + // Goal: fit all digits inside the image. + var border int + if width > height { + border = height / 4 + } else { + border = width / 4 + } + // Convert everything to floats for calculations. + w := float64(width - border*2) + h := float64(height - border*2) + // fw takes into account 1-dot spacing between digits. + fw := float64(fontWidth + 1) + fh := float64(fontHeight) + nc := float64(ncount) + // Calculate the width of a single digit taking into account only the + // width of the image. + nw := w / nc + // Calculate the height of a digit from this width. + nh := nw * fh / fw + // Digit too high? + if nh > h { + // Fit digits based on height. + nh = h + nw = fw / fh * nh + } + // Calculate dot size. + m.dotSize = int(nh / fh) + // Save everything, making the actual width smaller by 1 dot to account + // for spacing between digits. + m.numWidth = int(nw) - m.dotSize + m.numHeight = int(nh) +} + +func (m *Image) drawHorizLine(fromX, toX, y int, colorIdx uint8) { + for x := fromX; x <= toX; x++ { + m.SetColorIndex(x, y, colorIdx) + } +} + +func (m *Image) drawCircle(x, y, radius int, colorIdx uint8) { + f := 1 - radius + dfx := 1 + dfy := -2 * radius + xo := 0 + yo := radius + + m.SetColorIndex(x, y+radius, colorIdx) + m.SetColorIndex(x, y-radius, colorIdx) + m.drawHorizLine(x-radius, x+radius, y, colorIdx) + + for xo < yo { + if f >= 0 { + yo-- + dfy += 2 + f += dfy + } + xo++ + dfx += 2 + f += dfx + m.drawHorizLine(x-xo, x+xo, y+yo, colorIdx) + m.drawHorizLine(x-xo, x+xo, y-yo, colorIdx) + m.drawHorizLine(x-yo, x+yo, y+xo, colorIdx) + m.drawHorizLine(x-yo, x+yo, y-xo, colorIdx) + } +} + +func (m *Image) fillWithCircles(n, maxradius int) { + maxx := m.Bounds().Max.X + maxy := m.Bounds().Max.Y + for i := 0; i < n; i++ { + colorIdx := uint8(randInt(1, circleCount-1)) + r := randInt(1, maxradius) + m.drawCircle(randInt(r, maxx-r), randInt(r, maxy-r), r, colorIdx) + } +} + +func (m *Image) strikeThrough() { + maxx := m.Bounds().Max.X + maxy := m.Bounds().Max.Y + y := randInt(maxy/3, maxy-maxy/3) + amplitude := randFloat(5, 20) + period := randFloat(80, 180) + dx := 2.0 * math.Pi / period + for x := 0; x < maxx; x++ { + xo := amplitude * math.Cos(float64(y)*dx) + yo := amplitude * math.Sin(float64(x)*dx) + for yn := 0; yn < m.dotSize; yn++ { + r := randInt(0, m.dotSize) + m.drawCircle(x+int(xo), y+int(yo)+(yn*m.dotSize), r/2, 1) + } + } +} + +func (m *Image) drawDigit(digit []byte, x, y int) { + skf := randFloat(-maxSkew, maxSkew) + xs := float64(x) + r := m.dotSize / 2 + y += randInt(-r, r) + for yo := 0; yo < fontHeight; yo++ { + for xo := 0; xo < fontWidth; xo++ { + if digit[yo*fontWidth+xo] != blackChar { + continue + } + m.drawCircle(x+xo*m.dotSize, y+yo*m.dotSize, r, 1) + } + xs += skf + x = int(xs) + } +} + +func (m *Image) distort(amplude float64, period float64) { + w := m.Bounds().Max.X + h := m.Bounds().Max.Y + + oldm := m.Paletted + newm := image.NewPaletted(image.Rect(0, 0, w, h), oldm.Palette) + + dx := 2.0 * math.Pi / period + for x := 0; x < w; x++ { + for y := 0; y < h; y++ { + xo := amplude * math.Sin(float64(y)*dx) + yo := amplude * math.Cos(float64(x)*dx) + newm.SetColorIndex(x, y, oldm.ColorIndexAt(x+int(xo), y+int(yo))) + } + } + m.Paletted = newm +} + +func randomBrightness(c color.RGBA, max uint8) color.RGBA { + minc := min3(c.R, c.G, c.B) + maxc := max3(c.R, c.G, c.B) + if maxc > max { + return c + } + n := randIntn(int(max-maxc)) - int(minc) + return color.RGBA{ + uint8(int(c.R) + n), + uint8(int(c.G) + n), + uint8(int(c.B) + n), + uint8(c.A), + } +} + +func min3(x, y, z uint8) (m uint8) { + m = x + if y < m { + m = y + } + if z < m { + m = z + } + return +} + +func max3(x, y, z uint8) (m uint8) { + m = x + if y > m { + m = y + } + if z > m { + m = z + } + return +} diff --git a/utils/captcha/image_test.go b/utils/captcha/image_test.go new file mode 100644 index 00000000..e80cc42c --- /dev/null +++ b/utils/captcha/image_test.go @@ -0,0 +1,38 @@ +package captcha + +import ( + "testing" + + "github.com/astaxie/beego/utils" +) + +type byteCounter struct { + n int64 +} + +func (bc *byteCounter) Write(b []byte) (int, error) { + bc.n += int64(len(b)) + return len(b), nil +} + +func BenchmarkNewImage(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + for i := 0; i < b.N; i++ { + NewImage(d, stdWidth, stdHeight) + } +} + +func BenchmarkImageWriteTo(b *testing.B) { + b.StopTimer() + d := utils.RandomCreateBytes(challengeNums, defaultChars...) + b.StartTimer() + counter := &byteCounter{} + for i := 0; i < b.N; i++ { + img := NewImage(d, stdWidth, stdHeight) + img.WriteTo(counter) + b.SetBytes(counter.n) + counter.n = 0 + } +} diff --git a/utils/rand.go b/utils/rand.go new file mode 100644 index 00000000..482c7059 --- /dev/null +++ b/utils/rand.go @@ -0,0 +1,20 @@ +package utils + +import ( + "crypto/rand" +) + +// RandomCreateBytes generate random []byte by specify chars. +func RandomCreateBytes(n int, alphabets ...byte) []byte { + const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + var bytes = make([]byte, n) + rand.Read(bytes) + for i, b := range bytes { + if len(alphabets) == 0 { + bytes[i] = alphanum[b%byte(len(alphanum))] + } else { + bytes[i] = alphabets[b%byte(len(alphabets))] + } + } + return bytes +} From 91cbe1f29b4c6f30b915075c3bade60ea2b32ad1 Mon Sep 17 00:00:00 2001 From: slene Date: Thu, 16 Jan 2014 21:34:59 +0800 Subject: [PATCH 40/46] add some comments for captcha --- utils/captcha/captcha.go | 83 +++++++++++++++++++++++++++++++++++----- utils/captcha/image.go | 7 ++++ 2 files changed, 80 insertions(+), 10 deletions(-) diff --git a/utils/captcha/captcha.go b/utils/captcha/captcha.go index 174197a0..b4056f7a 100644 --- a/utils/captcha/captcha.go +++ b/utils/captcha/captcha.go @@ -1,7 +1,47 @@ +// an example for use captcha +// +// ``` +// package controllers +// +// import ( +// "github.com/astaxie/beego" +// "github.com/astaxie/beego/cache" +// "github.com/astaxie/beego/utils/captcha" +// ) +// +// var cpt *captcha.Captcha +// +// func init() { +// store := cache.NewMemoryCache() +// cpt = captcha.NewWithFilter("/captcha/", store) +// } +// +// type MainController struct { +// beego.Controller +// } +// +// func (this *MainController) Get() { +// this.TplNames = "index.tpl" +// } +// +// func (this *MainController) Post() { +// this.TplNames = "index.tpl" +// +// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +// } +// ``` +// +// template usage +// +// ``` +// {{.Success}} +//
            +// {{create_captcha}} +// +//
            +// ``` package captcha -// modifiy and integrated to Beego in one file from https://github.com/dchest/captcha - import ( "fmt" "html/template" @@ -20,6 +60,7 @@ var ( ) const ( + // default captcha attributes challengeNums = 6 expiration = 600 fieldIdName = "captcha_id" @@ -29,15 +70,29 @@ const ( ) type Captcha struct { - store cache.Cache - urlPrefix string - FieldIdName string + // beego cache store + store cache.Cache + + // url prefix for captcha image + urlPrefix string + + // specify captcha id input field name + FieldIdName string + // specify captcha result input field name FieldCaptchaName string - StdWidth int - StdHeight int - ChallengeNums int - Expiration int64 - CachePrefix string + + // captcha image width and height + StdWidth int + StdHeight int + + // captcha chars nums + ChallengeNums int + + // captcha expiration seconds + Expiration int64 + + // cache key prefix + CachePrefix string } func (c *Captcha) key(id string) string { @@ -48,6 +103,7 @@ func (c *Captcha) genRandChars() []byte { return utils.RandomCreateBytes(c.ChallengeNums, defaultChars...) } +// beego filter handler for serve captcha image func (c *Captcha) Handler(ctx *context.Context) { var chars []byte @@ -83,6 +139,7 @@ func (c *Captcha) Handler(ctx *context.Context) { } } +// tempalte func for output html func (c *Captcha) CreateCaptchaHtml() template.HTML { value, err := c.CreateCaptcha() if err != nil { @@ -97,6 +154,7 @@ func (c *Captcha) CreateCaptchaHtml() template.HTML { ``, c.FieldIdName, value, c.urlPrefix, value, c.urlPrefix, value)) } +// create a new captcha id func (c *Captcha) CreateCaptcha() (string, error) { // generate captcha id id := string(utils.RandomCreateBytes(15)) @@ -112,11 +170,13 @@ func (c *Captcha) CreateCaptcha() (string, error) { return id, nil } +// verify from a request func (c *Captcha) VerifyReq(req *http.Request) bool { req.ParseForm() return c.Verify(req.Form.Get(c.FieldIdName), req.Form.Get(c.FieldCaptchaName)) } +// direct verify id and challenge string func (c *Captcha) Verify(id string, challenge string) (success bool) { if len(challenge) == 0 || len(id) == 0 { return @@ -147,6 +207,7 @@ func (c *Captcha) Verify(id string, challenge string) (success bool) { return true } +// create a new captcha.Captcha func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { cpt := &Captcha{} cpt.store = store @@ -171,6 +232,8 @@ func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { return cpt } +// create a new captcha.Captcha and auto AddFilter for serve captacha image +// and add a tempalte func for output html func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { cpt := NewCaptcha(urlPrefix, store) diff --git a/utils/captcha/image.go b/utils/captcha/image.go index d9ae13a6..af1ed167 100644 --- a/utils/captcha/image.go +++ b/utils/captcha/image.go @@ -1,3 +1,4 @@ +// modifiy and integrated to Beego from https://github.com/dchest/captcha package captcha import ( @@ -240,10 +241,16 @@ func getrand() *rand.Rand { } func randIntn(max int) int { + if max <= 0 { + return 0 + } return getrand().Intn(max) } func randInt(min, max int) int { + if max-min <= 0 { + return 0 + } return getrand().Intn(max-min) + min } From 7d5ee0d6923d300f38519205f1fe07d73eb99321 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A1=B9=E8=B6=85?= Date: Fri, 17 Jan 2014 00:17:43 +0800 Subject: [PATCH 41/46] Update README.md --- cache/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cache/README.md b/cache/README.md index d4c80fad..c0bfcc59 100644 --- a/cache/README.md +++ b/cache/README.md @@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether ## Memcache adapter -memory adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client. +Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client. Configure like this: From 91d75e892586b10900dd70141186a5539ffdb776 Mon Sep 17 00:00:00 2001 From: slene Date: Fri, 17 Jan 2014 12:07:30 +0800 Subject: [PATCH 42/46] add readme for captcha, and enhanced performance --- utils/captcha/README.md | 45 ++++++ utils/captcha/captcha.go | 1 + utils/captcha/image.go | 29 ++-- utils/captcha/siprng.go | 264 +++++++++++++++++++++++++++++++++++ utils/captcha/siprng_test.go | 19 +++ 5 files changed, 340 insertions(+), 18 deletions(-) create mode 100644 utils/captcha/README.md create mode 100644 utils/captcha/siprng.go create mode 100644 utils/captcha/siprng_test.go diff --git a/utils/captcha/README.md b/utils/captcha/README.md new file mode 100644 index 00000000..cbdd5bde --- /dev/null +++ b/utils/captcha/README.md @@ -0,0 +1,45 @@ +# Captcha + +an example for use captcha + +``` +package controllers + +import ( + "github.com/astaxie/beego" + "github.com/astaxie/beego/cache" + "github.com/astaxie/beego/utils/captcha" +) + +var cpt *captcha.Captcha + +func init() { + // use beego cache system store the captcha data + store := cache.NewMemoryCache() + cpt = captcha.NewWithFilter("/captcha/", store) +} + +type MainController struct { + beego.Controller +} + +func (this *MainController) Get() { + this.TplNames = "index.tpl" +} + +func (this *MainController) Post() { + this.TplNames = "index.tpl" + + this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +} +``` + +template usage + +``` +{{.Success}} +
            + {{create_captcha}} + +
            +``` diff --git a/utils/captcha/captcha.go b/utils/captcha/captcha.go index b4056f7a..beaff19c 100644 --- a/utils/captcha/captcha.go +++ b/utils/captcha/captcha.go @@ -12,6 +12,7 @@ // var cpt *captcha.Captcha // // func init() { +// // use beego cache system store the captcha data // store := cache.NewMemoryCache() // cpt = captcha.NewWithFilter("/captcha/", store) // } diff --git a/utils/captcha/image.go b/utils/captcha/image.go index af1ed167..31c66a36 100644 --- a/utils/captcha/image.go +++ b/utils/captcha/image.go @@ -8,8 +8,6 @@ import ( "image/png" "io" "math" - "math/rand" - "time" ) const ( @@ -236,26 +234,21 @@ type Image struct { dotSize int } -func getrand() *rand.Rand { - return rand.New(rand.NewSource(time.Now().UnixNano())) +var prng = &siprng{} + +// randIntn returns a pseudorandom non-negative int in range [0, n). +func randIntn(n int) int { + return prng.Intn(n) } -func randIntn(max int) int { - if max <= 0 { - return 0 - } - return getrand().Intn(max) +// randInt returns a pseudorandom int in range [from, to]. +func randInt(from, to int) int { + return prng.Intn(to+1-from) + from } -func randInt(min, max int) int { - if max-min <= 0 { - return 0 - } - return getrand().Intn(max-min) + min -} - -func randFloat(min, max float64) float64 { - return (max-min)*getrand().Float64() + min +// randFloat returns a pseudorandom float64 in range [from, to]. +func randFloat(from, to float64) float64 { + return (to-from)*prng.Float64() + from } func randomPalette() color.Palette { diff --git a/utils/captcha/siprng.go b/utils/captcha/siprng.go new file mode 100644 index 00000000..6f9274d8 --- /dev/null +++ b/utils/captcha/siprng.go @@ -0,0 +1,264 @@ +// modifiy and integrated to Beego from https://github.com/dchest/captcha +package captcha + +import ( + "crypto/rand" + "encoding/binary" + "io" + "sync" +) + +// siprng is PRNG based on SipHash-2-4. +type siprng struct { + mu sync.Mutex + k0, k1, ctr uint64 +} + +// siphash implements SipHash-2-4, accepting a uint64 as a message. +func siphash(k0, k1, m uint64) uint64 { + // Initialization. + v0 := k0 ^ 0x736f6d6570736575 + v1 := k1 ^ 0x646f72616e646f6d + v2 := k0 ^ 0x6c7967656e657261 + v3 := k1 ^ 0x7465646279746573 + t := uint64(8) << 56 + + // Compression. + v3 ^= m + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + v0 ^= m + + // Compress last block. + v3 ^= t + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + v0 ^= t + + // Finalization. + v2 ^= 0xff + + // Round 1. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 2. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 3. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + // Round 4. + v0 += v1 + v1 = v1<<13 | v1>>(64-13) + v1 ^= v0 + v0 = v0<<32 | v0>>(64-32) + + v2 += v3 + v3 = v3<<16 | v3>>(64-16) + v3 ^= v2 + + v0 += v3 + v3 = v3<<21 | v3>>(64-21) + v3 ^= v0 + + v2 += v1 + v1 = v1<<17 | v1>>(64-17) + v1 ^= v2 + v2 = v2<<32 | v2>>(64-32) + + return v0 ^ v1 ^ v2 ^ v3 +} + +// rekey sets a new PRNG key, which is read from crypto/rand. +func (p *siprng) rekey() { + var k [16]byte + if _, err := io.ReadFull(rand.Reader, k[:]); err != nil { + panic(err.Error()) + } + p.k0 = binary.LittleEndian.Uint64(k[0:8]) + p.k1 = binary.LittleEndian.Uint64(k[8:16]) + p.ctr = 1 +} + +// Uint64 returns a new pseudorandom uint64. +// It rekeys PRNG on the first call and every 64 MB of generated data. +func (p *siprng) Uint64() uint64 { + p.mu.Lock() + if p.ctr == 0 || p.ctr > 8*1024*1024 { + p.rekey() + } + v := siphash(p.k0, p.k1, p.ctr) + p.ctr++ + p.mu.Unlock() + return v +} + +func (p *siprng) Int63() int64 { + return int64(p.Uint64() & 0x7fffffffffffffff) +} + +func (p *siprng) Uint32() uint32 { + return uint32(p.Uint64()) +} + +func (p *siprng) Int31() int32 { + return int32(p.Uint32() & 0x7fffffff) +} + +func (p *siprng) Intn(n int) int { + if n <= 0 { + panic("invalid argument to Intn") + } + if n <= 1<<31-1 { + return int(p.Int31n(int32(n))) + } + return int(p.Int63n(int64(n))) +} + +func (p *siprng) Int63n(n int64) int64 { + if n <= 0 { + panic("invalid argument to Int63n") + } + max := int64((1 << 63) - 1 - (1<<63)%uint64(n)) + v := p.Int63() + for v > max { + v = p.Int63() + } + return v % n +} + +func (p *siprng) Int31n(n int32) int32 { + if n <= 0 { + panic("invalid argument to Int31n") + } + max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) + v := p.Int31() + for v > max { + v = p.Int31() + } + return v % n +} + +func (p *siprng) Float64() float64 { return float64(p.Int63()) / (1 << 63) } diff --git a/utils/captcha/siprng_test.go b/utils/captcha/siprng_test.go new file mode 100644 index 00000000..2fb81f7e --- /dev/null +++ b/utils/captcha/siprng_test.go @@ -0,0 +1,19 @@ +package captcha + +import "testing" + +func TestSiphash(t *testing.T) { + good := uint64(0xe849e8bb6ffe2567) + cur := siphash(0, 0, 0) + if cur != good { + t.Fatalf("siphash: expected %x, got %x", good, cur) + } +} + +func BenchmarkSiprng(b *testing.B) { + b.SetBytes(8) + p := &siprng{} + for i := 0; i < b.N; i++ { + p.Uint64() + } +} From 32799bc259973944b7dfe38dfd3f432bde6c7e08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=B0=8F=E9=BB=91?= Date: Fri, 17 Jan 2014 16:03:01 +0800 Subject: [PATCH 43/46] add comments for middleware packages, fix typo error --- beego.go | 2 +- middleware/error.go | 19 +++++++++++++------ middleware/exceptions.go | 5 ++++- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/beego.go b/beego.go index c22719c6..640fb5ab 100644 --- a/beego.go +++ b/beego.go @@ -221,7 +221,7 @@ func Run() { middleware.VERSION = VERSION middleware.AppName = AppName - middleware.RegisterErrorHander() + middleware.RegisterErrorHandler() if EnableAdmin { go BeeAdminApp.Run() diff --git a/middleware/error.go b/middleware/error.go index 35d9eb59..5c12b533 100644 --- a/middleware/error.go +++ b/middleware/error.go @@ -61,6 +61,7 @@ var tpl = ` ` +// render default application error page with error and stack string. func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) { t, _ := template.New("beegoerrortemp").Parse(tpl) data := make(map[string]string) @@ -175,13 +176,14 @@ var errtpl = ` ` +// map of http handlers for each error string. var ErrorMaps map[string]http.HandlerFunc func init() { ErrorMaps = make(map[string]http.HandlerFunc) } -//404 +// show 404 notfound error. func NotFound(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) @@ -199,7 +201,7 @@ func NotFound(rw http.ResponseWriter, r *http.Request) { t.Execute(rw, data) } -//401 +// show 401 unauthorized error. func Unauthorized(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) @@ -215,7 +217,7 @@ func Unauthorized(rw http.ResponseWriter, r *http.Request) { t.Execute(rw, data) } -//403 +// show 403 forbidden error. func Forbidden(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) @@ -232,7 +234,7 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) { t.Execute(rw, data) } -//503 +// show 503 service unavailable error. func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) @@ -248,7 +250,7 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) { t.Execute(rw, data) } -//500 +// show 500 internal server error. func InternalServerError(rw http.ResponseWriter, r *http.Request) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := make(map[string]interface{}) @@ -262,15 +264,18 @@ func InternalServerError(rw http.ResponseWriter, r *http.Request) { t.Execute(rw, data) } +// show 500 internal error with simple text string. func SimpleServerError(rw http.ResponseWriter, r *http.Request) { http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } +// add http handler for given error string. func Errorhandler(err string, h http.HandlerFunc) { ErrorMaps[err] = h } -func RegisterErrorHander() { +// register default error http handlers, 404,401,403,500 and 503. +func RegisterErrorHandler() { if _, ok := ErrorMaps["404"]; !ok { ErrorMaps["404"] = NotFound } @@ -292,6 +297,8 @@ func RegisterErrorHander() { } } +// show error string as simple text message. +// if error string is empty, show 500 error as default. func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) { if h, ok := ErrorMaps[errcode]; ok { isint, err := strconv.Atoi(errcode) diff --git a/middleware/exceptions.go b/middleware/exceptions.go index 5bf85956..b221dfcb 100644 --- a/middleware/exceptions.go +++ b/middleware/exceptions.go @@ -2,16 +2,19 @@ package middleware import "fmt" +// http exceptions type HTTPException struct { StatusCode int // http status code 4xx, 5xx Description string } +// return http exception error string, e.g. "400 Bad Request". func (e *HTTPException) Error() string { - // return `status description`, e.g. `400 Bad Request` return fmt.Sprintf("%d %s", e.StatusCode, e.Description) } +// map of http exceptions for each http status code int. +// defined 400,401,403,404,405,500,502,503 and 504 default. var HTTPExceptionMaps map[int]HTTPException func init() { From f5a5ebe16ba8688de57b81b3779742c342a2de18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=B0=8F=E9=BB=91?= Date: Fri, 17 Jan 2014 17:04:15 +0800 Subject: [PATCH 44/46] add comments for orm packages, part 1 --- orm/cmd.go | 12 ++++++++++++ orm/cmd_utils.go | 4 ++++ orm/db.go | 41 ++++++++++++++++++++++++++++++++++++++++- orm/db_alias.go | 19 ++++++++++++++----- 4 files changed, 70 insertions(+), 6 deletions(-) diff --git a/orm/cmd.go b/orm/cmd.go index 97545da4..95be7f4a 100644 --- a/orm/cmd.go +++ b/orm/cmd.go @@ -16,6 +16,7 @@ var ( commands = make(map[string]commander) ) +// print help. func printHelp(errs ...string) { content := `orm command usage: @@ -31,6 +32,7 @@ func printHelp(errs ...string) { os.Exit(2) } +// listen for orm command and then run it if command arguments passed. func RunCommand() { if len(os.Args) < 2 || os.Args[1] != "orm" { return @@ -58,6 +60,7 @@ func RunCommand() { } } +// sync database struct command interface. type commandSyncDb struct { al *alias force bool @@ -66,6 +69,7 @@ type commandSyncDb struct { rtOnError bool } +// parse orm command line arguments. func (d *commandSyncDb) Parse(args []string) { var name string @@ -78,6 +82,7 @@ func (d *commandSyncDb) Parse(args []string) { d.al = getDbAlias(name) } +// run orm line command. func (d *commandSyncDb) Run() error { var drops []string if d.force { @@ -208,10 +213,12 @@ func (d *commandSyncDb) Run() error { return nil } +// database creation commander interface implement. type commandSqlAll struct { al *alias } +// parse orm command line arguments. func (d *commandSqlAll) Parse(args []string) { var name string @@ -222,6 +229,7 @@ func (d *commandSqlAll) Parse(args []string) { d.al = getDbAlias(name) } +// run orm line command. func (d *commandSqlAll) Run() error { sqls, indexes := getDbCreateSql(d.al) var all []string @@ -243,6 +251,10 @@ func init() { commands["sqlall"] = new(commandSqlAll) } +// run syncdb command line. +// name means table's alias name. default is "default". +// force means run next sql if the current is error. +// verbose means show all info when running command or not. func RunSyncdb(name string, force bool, verbose bool) error { BootStrap() diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go index 6fcb4b01..8f6d94db 100644 --- a/orm/cmd_utils.go +++ b/orm/cmd_utils.go @@ -12,6 +12,7 @@ type dbIndex struct { Sql string } +// create database drop sql. func getDbDropSql(al *alias) (sqls []string) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") @@ -26,6 +27,7 @@ func getDbDropSql(al *alias) (sqls []string) { return sqls } +// get database column type string. func getColumnTyp(al *alias, fi *fieldInfo) (col string) { T := al.DbBaser.DbTypes() fieldType := fi.fieldType @@ -79,6 +81,7 @@ checkColumn: return } +// create alter sql string. func getColumnAddQuery(al *alias, fi *fieldInfo) string { Q := al.DbBaser.TableQuote() typ := getColumnTyp(al, fi) @@ -90,6 +93,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string { return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ) } +// create database creation string. func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { if len(modelCache.cache) == 0 { fmt.Println("no Model found, need register your model") diff --git a/orm/db.go b/orm/db.go index c6e92ec9..10967fc5 100644 --- a/orm/db.go +++ b/orm/db.go @@ -15,7 +15,7 @@ const ( ) var ( - ErrMissPK = errors.New("missed pk value") + ErrMissPK = errors.New("missed pk value") // missing pk error ) var ( @@ -45,12 +45,15 @@ var ( } ) +// an instance of dbBaser interface/ type dbBase struct { ins dbBaser } +// check dbBase implements dbBaser interface. var _ dbBaser = new(dbBase) +// get struct columns values as interface slice. func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) { var columns []string @@ -87,6 +90,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, return } +// get one field value in struct column as interface. func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { var value interface{} if fi.pk { @@ -155,6 +159,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val return value, nil } +// create insert sql preparation statement object. func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { Q := d.ins.TableQuote() @@ -180,6 +185,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, return stmt, query, err } +// insert struct with prepared statement and given struct reflect value. func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { @@ -200,6 +206,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } } +// query sql ,read records and persist in dbBaser. func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error { var whereCols []string var args []interface{} @@ -259,6 +266,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo return nil } +// execute insert sql dbQuerier with given struct reflect.Value. func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { names := make([]string, 0, len(mi.fields.dbcols)-1) values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz) @@ -269,6 +277,7 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return d.InsertValue(q, mi, false, names, values) } +// multi-insert sql with given slice struct reflect.Value. func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { var ( cnt int64 @@ -325,6 +334,8 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul return cnt, nil } +// execute insert sql with given struct and given values. +// insert the given values, not the field values in struct. func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() @@ -364,6 +375,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s } } +// execute update sql dbQuerier with given struct reflect.Value. func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { @@ -404,6 +416,8 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, nil } +// execute delete sql dbQuerier with given struct reflect.Value. +// delete index is pk. func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if ok == false { @@ -445,6 +459,8 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, nil } +// update table-related record by querySet. +// need querySet not struct reflect.Value to update related records. func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) @@ -520,6 +536,8 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } +// delete related records. +// do UpdateBanch or DeleteBanch by condition of tables' relationship. func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo @@ -546,6 +564,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * return nil } +// delete table-related records. func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) tables.skipEnd = true @@ -623,6 +642,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, nil } +// read related records. func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) @@ -832,6 +852,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi return cnt, nil } +// excute count sql and return count result int64. func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) @@ -852,6 +873,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition return } +// generate sql with replacing operator string placeholders and replaced values. func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { sql := "" params := getFlatParams(fi, args, tz) @@ -909,6 +931,7 @@ func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { // default not use } +// set values to struct column. func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { for i, column := range cols { val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() @@ -930,6 +953,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, } } +// convert value from database result to value following in field type. func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { if val == nil { return nil, nil @@ -1082,6 +1106,7 @@ end: } +// set one value to struct column field. func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { fieldType := fi.fieldType @@ -1156,6 +1181,7 @@ setValue: return value, nil } +// query sql, read values , save to *[]ParamList. func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( @@ -1323,6 +1349,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond return cnt, nil } +// flag of update joined record. func (d *dbBase) SupportUpdateJoin() bool { return true } @@ -1331,30 +1358,37 @@ func (d *dbBase) MaxLimit() uint64 { return 18446744073709551615 } +// return quote. func (d *dbBase) TableQuote() string { return "`" } +// replace value placeholer in parametered sql string. func (d *dbBase) ReplaceMarks(query *string) { // default use `?` as mark, do nothing } +// flag of RETURNING sql. func (d *dbBase) HasReturningID(*modelInfo, *string) bool { return false } +// convert time from db. func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { *t = t.In(tz) } +// convert time to db. func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { *t = t.In(tz) } +// get database types. func (d *dbBase) DbTypes() map[string]string { return nil } +// gt all tables. func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { tables := make(map[string]bool) query := d.ins.ShowTablesQuery() @@ -1379,6 +1413,7 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { return tables, nil } +// get all cloumns in table. func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { columns := make(map[string][3]string) query := d.ins.ShowColumnsQuery(table) @@ -1405,18 +1440,22 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e return columns, nil } +// not implement. func (d *dbBase) OperatorSql(operator string) string { panic(ErrNotImplement) } +// not implement. func (d *dbBase) ShowTablesQuery() string { panic(ErrNotImplement) } +// not implement. func (d *dbBase) ShowColumnsQuery(table string) string { panic(ErrNotImplement) } +// not implement. func (d *dbBase) IndexExists(dbQuerier, string, string) bool { panic(ErrNotImplement) } diff --git a/orm/db_alias.go b/orm/db_alias.go index 24924312..d50b6ebd 100644 --- a/orm/db_alias.go +++ b/orm/db_alias.go @@ -9,27 +9,32 @@ import ( "time" ) +// database driver constant int. type DriverType int const ( - _ DriverType = iota - DR_MySQL - DR_Sqlite - DR_Oracle - DR_Postgres + _ DriverType = iota // int enum type + DR_MySQL // mysql + DR_Sqlite // sqlite + DR_Oracle // oracle + DR_Postgres // pgsql ) +// database driver string. type driver string +// get type constant int of current driver.. func (d driver) Type() DriverType { a, _ := dataBaseCache.get(string(d)) return a.Driver } +// get name of current driver func (d driver) Name() string { return string(d) } +// check driver iis implemented Driver interface or not. var _ Driver = new(driver) var ( @@ -47,11 +52,13 @@ var ( } ) +// database alias cacher. type _dbCache struct { mux sync.RWMutex cache map[string]*alias } +// add database alias with original name. func (ac *_dbCache) add(name string, al *alias) (added bool) { ac.mux.Lock() defer ac.mux.Unlock() @@ -62,6 +69,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) { return } +// get database alias if cached. func (ac *_dbCache) get(name string) (al *alias, ok bool) { ac.mux.RLock() defer ac.mux.RUnlock() @@ -69,6 +77,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) { return } +// get default alias. func (ac *_dbCache) getDefault() (al *alias) { al, _ = ac.get("default") return From 4c527dde65b78d61a0fc9b190dabe9272a309fb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=B0=8F=E9=BB=91?= Date: Fri, 17 Jan 2014 17:25:17 +0800 Subject: [PATCH 45/46] add comments for orm packages, part 2 --- orm/db.go | 1 + orm/db_mysql.go | 9 +++++++++ orm/db_oracle.go | 2 ++ orm/db_postgres.go | 15 +++++++++++++++ orm/db_sqlite.go | 14 ++++++++++++++ orm/db_tables.go | 15 +++++++++++++++ orm/db_utils.go | 3 +++ 7 files changed, 59 insertions(+) diff --git a/orm/db.go b/orm/db.go index 10967fc5..f12e76fb 100644 --- a/orm/db.go +++ b/orm/db.go @@ -927,6 +927,7 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator stri return sql, params } +// gernerate sql string with inner function, such as UPPER(text). func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { // default not use } diff --git a/orm/db_mysql.go b/orm/db_mysql.go index da123079..566f2992 100644 --- a/orm/db_mysql.go +++ b/orm/db_mysql.go @@ -4,6 +4,7 @@ import ( "fmt" ) +// mysql operators. var mysqlOperators = map[string]string{ "exact": "= ?", "iexact": "LIKE ?", @@ -21,6 +22,7 @@ var mysqlOperators = map[string]string{ "iendswith": "LIKE ?", } +// mysql column field types. var mysqlTypes = map[string]string{ "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY", @@ -41,29 +43,35 @@ var mysqlTypes = map[string]string{ "float64-decimal": "numeric(%d, %d)", } +// mysql dbBaser implementation. type dbBaseMysql struct { dbBase } var _ dbBaser = new(dbBaseMysql) +// get mysql operator. func (d *dbBaseMysql) OperatorSql(operator string) string { return mysqlOperators[operator] } +// get mysql table field types. func (d *dbBaseMysql) DbTypes() map[string]string { return mysqlTypes } +// show table sql for mysql. func (d *dbBaseMysql) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" } +// show columns sql of table for mysql. func (d *dbBaseMysql) ShowColumnsQuery(table string) string { return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ "WHERE table_schema = DATABASE() AND table_name = '%s'", table) } +// execute sql to check index exist. func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) @@ -72,6 +80,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool return cnt > 0 } +// create new mysql dbBaser. func newdbBaseMysql() dbBaser { b := new(dbBaseMysql) b.ins = b diff --git a/orm/db_oracle.go b/orm/db_oracle.go index ca1715ef..8e374122 100644 --- a/orm/db_oracle.go +++ b/orm/db_oracle.go @@ -1,11 +1,13 @@ package orm +// oracle dbBaser type dbBaseOracle struct { dbBase } var _ dbBaser = new(dbBaseOracle) +// create oracle dbBaser. func newdbBaseOracle() dbBaser { b := new(dbBaseOracle) b.ins = b diff --git a/orm/db_postgres.go b/orm/db_postgres.go index 4058fc10..d26511c0 100644 --- a/orm/db_postgres.go +++ b/orm/db_postgres.go @@ -5,6 +5,7 @@ import ( "strconv" ) +// postgresql operators. var postgresOperators = map[string]string{ "exact": "= ?", "iexact": "= UPPER(?)", @@ -20,6 +21,7 @@ var postgresOperators = map[string]string{ "iendswith": "LIKE UPPER(?)", } +// postgresql column field types. var postgresTypes = map[string]string{ "auto": "serial NOT NULL PRIMARY KEY", "pk": "NOT NULL PRIMARY KEY", @@ -40,16 +42,19 @@ var postgresTypes = map[string]string{ "float64-decimal": "numeric(%d, %d)", } +// postgresql dbBaser. type dbBasePostgres struct { dbBase } var _ dbBaser = new(dbBasePostgres) +// get postgresql operator. func (d *dbBasePostgres) OperatorSql(operator string) string { return postgresOperators[operator] } +// generate functioned sql string, such as contains(text). func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { switch operator { case "contains", "startswith", "endswith": @@ -59,6 +64,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, } } +// postgresql unsupports updating joined record. func (d *dbBasePostgres) SupportUpdateJoin() bool { return false } @@ -67,10 +73,13 @@ func (d *dbBasePostgres) MaxLimit() uint64 { return 0 } +// postgresql quote is ". func (d *dbBasePostgres) TableQuote() string { return `"` } +// postgresql value placeholder is $n. +// replace default ? to $n. func (d *dbBasePostgres) ReplaceMarks(query *string) { q := *query num := 0 @@ -97,6 +106,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { *query = string(data) } +// make returning sql support for postgresql. func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { if mi.fields.pk.auto { if query != nil { @@ -107,18 +117,22 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) return } +// show table sql for postgresql. func (d *dbBasePostgres) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" } +// show table columns sql for postgresql. func (d *dbBasePostgres) ShowColumnsQuery(table string) string { return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) } +// get column types of postgresql. func (d *dbBasePostgres) DbTypes() map[string]string { return postgresTypes } +// check index exist in postgresql. func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) row := db.QueryRow(query) @@ -127,6 +141,7 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo return cnt > 0 } +// create new postgresql dbBaser. func newdbBasePostgres() dbBaser { b := new(dbBasePostgres) b.ins = b diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go index 7711ded0..81692e2c 100644 --- a/orm/db_sqlite.go +++ b/orm/db_sqlite.go @@ -5,6 +5,7 @@ import ( "fmt" ) +// sqlite operators. var sqliteOperators = map[string]string{ "exact": "= ?", "iexact": "LIKE ? ESCAPE '\\'", @@ -20,6 +21,7 @@ var sqliteOperators = map[string]string{ "iendswith": "LIKE ? ESCAPE '\\'", } +// sqlite column types. var sqliteTypes = map[string]string{ "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", "pk": "NOT NULL PRIMARY KEY", @@ -40,38 +42,47 @@ var sqliteTypes = map[string]string{ "float64-decimal": "decimal", } +// sqlite dbBaser. type dbBaseSqlite struct { dbBase } var _ dbBaser = new(dbBaseSqlite) +// get sqlite operator. func (d *dbBaseSqlite) OperatorSql(operator string) string { return sqliteOperators[operator] } +// generate functioned sql for sqlite. +// only support DATE(text). func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { if fi.fieldType == TypeDateField { *leftCol = fmt.Sprintf("DATE(%s)", *leftCol) } } +// unable updating joined record in sqlite. func (d *dbBaseSqlite) SupportUpdateJoin() bool { return false } +// max int in sqlite. func (d *dbBaseSqlite) MaxLimit() uint64 { return 9223372036854775807 } +// get column types in sqlite. func (d *dbBaseSqlite) DbTypes() map[string]string { return sqliteTypes } +// get show tables sql in sqlite. func (d *dbBaseSqlite) ShowTablesQuery() string { return "SELECT name FROM sqlite_master WHERE type = 'table'" } +// get columns in sqlite. func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { query := d.ins.ShowColumnsQuery(table) rows, err := db.Query(query) @@ -92,10 +103,12 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str return columns, nil } +// get show columns sql in sqlite. func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { return fmt.Sprintf("pragma table_info('%s')", table) } +// check index exist in sqlite. func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { query := fmt.Sprintf("PRAGMA index_list('%s')", table) rows, err := db.Query(query) @@ -113,6 +126,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool return false } +// create new sqlite dbBaser. func newdbBaseSqlite() dbBaser { b := new(dbBaseSqlite) b.ins = b diff --git a/orm/db_tables.go b/orm/db_tables.go index f5cacf38..854c4214 100644 --- a/orm/db_tables.go +++ b/orm/db_tables.go @@ -6,6 +6,7 @@ import ( "time" ) +// table info struct. type dbTable struct { id int index string @@ -18,6 +19,7 @@ type dbTable struct { jtl *dbTable } +// tables collection struct, contains some tables. type dbTables struct { tablesM map[string]*dbTable tables []*dbTable @@ -26,6 +28,8 @@ type dbTables struct { skipEnd bool } +// set table info to collection. +// if not exist, create new. func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { name := strings.Join(names, ExprSep) if j, ok := t.tablesM[name]; ok { @@ -42,6 +46,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) return t.tablesM[name] } +// add table info to collection. func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { name := strings.Join(names, ExprSep) if _, ok := t.tablesM[name]; ok == false { @@ -54,11 +59,14 @@ func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) return t.tablesM[name], false } +// get table info in collection. func (t *dbTables) get(name string) (*dbTable, bool) { j, ok := t.tablesM[name] return j, ok } +// get related fields info in recursive depth loop. +// loop once, depth decreases one. func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { if depth < 0 || fi.fieldType == RelManyToMany { return related @@ -79,6 +87,7 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related [] return related } +// parse related fields. func (t *dbTables) parseRelated(rels []string, depth int) { relsNum := len(rels) @@ -140,6 +149,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) { } } +// generate join string. func (t *dbTables) getJoinSql() (join string) { Q := t.base.TableQuote() @@ -186,6 +196,7 @@ func (t *dbTables) getJoinSql() (join string) { return } +// parse orm model struct field tag expression. func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { var ( jtl *dbTable @@ -300,6 +311,7 @@ loopFor: return } +// generate condition sql. func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { if cond == nil || cond.IsEmpty() { return @@ -364,6 +376,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe return } +// generate order sql. func (t *dbTables) getOrderSql(orders []string) (orderSql string) { if len(orders) == 0 { return @@ -392,6 +405,7 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) { return } +// generate limit sql. func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) { if limit == 0 { limit = int64(DefaultRowsLimit) @@ -414,6 +428,7 @@ func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits return } +// crete new tables collection. func newDbTables(mi *modelInfo, base dbBaser) *dbTables { tables := &dbTables{} tables.tablesM = make(map[string]*dbTable) diff --git a/orm/db_utils.go b/orm/db_utils.go index e2178294..34de8186 100644 --- a/orm/db_utils.go +++ b/orm/db_utils.go @@ -6,6 +6,7 @@ import ( "time" ) +// get table alias. func getDbAlias(name string) *alias { if al, ok := dataBaseCache.get(name); ok { return al @@ -15,6 +16,7 @@ func getDbAlias(name string) *alias { return nil } +// get pk column info. func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { fi := mi.fields.pk @@ -37,6 +39,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac return } +// get fields description as flatted string. func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { outFor: From 828a30606999e2e5fa2cac11327bb11378f50190 Mon Sep 17 00:00:00 2001 From: FuXiaoHei Date: Fri, 17 Jan 2014 23:28:54 +0800 Subject: [PATCH 46/46] add comments for orm package, done --- orm/models.go | 7 +++++++ orm/models_boot.go | 6 ++++++ orm/models_info_f.go | 8 ++++++++ orm/models_info_m.go | 4 ++++ orm/models_utils.go | 8 ++++++++ orm/orm.go | 37 +++++++++++++++++++++++++++++++++++++ orm/orm_conds.go | 11 +++++++++++ orm/orm_log.go | 5 +++++ orm/orm_object.go | 4 ++++ orm/orm_querym2m.go | 13 +++++++++++++ orm/orm_queryset.go | 37 +++++++++++++++++++++++++++++++++++++ orm/orm_raw.go | 12 ++++++++++++ orm/types.go | 13 +++++++++++++ orm/utils.go | 27 +++++++++++++++++++++++++++ 14 files changed, 192 insertions(+) diff --git a/orm/models.go b/orm/models.go index 1cb25c4c..5744d865 100644 --- a/orm/models.go +++ b/orm/models.go @@ -41,6 +41,7 @@ var ( } ) +// model info collection type _modelCache struct { sync.RWMutex orders []string @@ -49,6 +50,7 @@ type _modelCache struct { done bool } +// get all model info func (mc *_modelCache) all() map[string]*modelInfo { m := make(map[string]*modelInfo, len(mc.cache)) for k, v := range mc.cache { @@ -57,6 +59,7 @@ func (mc *_modelCache) all() map[string]*modelInfo { return m } +// get orderd model info func (mc *_modelCache) allOrdered() []*modelInfo { m := make([]*modelInfo, 0, len(mc.orders)) for _, table := range mc.orders { @@ -65,16 +68,19 @@ func (mc *_modelCache) allOrdered() []*modelInfo { return m } +// get model info by table name func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { mi, ok = mc.cache[table] return } +// get model info by field name func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) { mi, ok = mc.cacheByFN[name] return } +// set model info to collection func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { mii := mc.cache[table] mc.cache[table] = mi @@ -85,6 +91,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { return mii } +// clean all model info. func (mc *_modelCache) clean() { mc.orders = make([]string, 0) mc.cache = make(map[string]*modelInfo) diff --git a/orm/models_boot.go b/orm/models_boot.go index 3274b187..03caeb62 100644 --- a/orm/models_boot.go +++ b/orm/models_boot.go @@ -8,6 +8,8 @@ import ( "strings" ) +// register models. +// prefix means table name prefix. func registerModel(model interface{}, prefix string) { val := reflect.ValueOf(model) ind := reflect.Indirect(val) @@ -67,6 +69,7 @@ func registerModel(model interface{}, prefix string) { modelCache.set(table, info) } +// boostrap models func bootStrap() { if modelCache.done { return @@ -281,6 +284,7 @@ end: } } +// register models func RegisterModel(models ...interface{}) { if modelCache.done { panic(fmt.Errorf("RegisterModel must be run before BootStrap")) @@ -302,6 +306,8 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) { } } +// bootrap models. +// make all model parsed and can not add more models func BootStrap() { if modelCache.done { return diff --git a/orm/models_info_f.go b/orm/models_info_f.go index 03736091..fadbb335 100644 --- a/orm/models_info_f.go +++ b/orm/models_info_f.go @@ -9,6 +9,7 @@ import ( var errSkipField = errors.New("skip field") +// field info collection type fields struct { pk *fieldInfo columns map[string]*fieldInfo @@ -23,6 +24,7 @@ type fields struct { dbcols []string } +// add field info func (f *fields) Add(fi *fieldInfo) (added bool) { if f.fields[fi.name] == nil && f.columns[fi.column] == nil { f.columns[fi.column] = fi @@ -49,14 +51,17 @@ func (f *fields) Add(fi *fieldInfo) (added bool) { return true } +// get field info by name func (f *fields) GetByName(name string) *fieldInfo { return f.fields[name] } +// get field info by column name func (f *fields) GetByColumn(column string) *fieldInfo { return f.columns[column] } +// get field info by string, name is prior func (f *fields) GetByAny(name string) (*fieldInfo, bool) { if fi, ok := f.fields[name]; ok { return fi, ok @@ -70,6 +75,7 @@ func (f *fields) GetByAny(name string) (*fieldInfo, bool) { return nil, false } +// create new field info collection func newFields() *fields { f := new(fields) f.fields = make(map[string]*fieldInfo) @@ -79,6 +85,7 @@ func newFields() *fields { return f } +// single field info type fieldInfo struct { mi *modelInfo fieldIndex int @@ -115,6 +122,7 @@ type fieldInfo struct { onDelete string } +// new field info func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) { var ( tag string diff --git a/orm/models_info_m.go b/orm/models_info_m.go index 7a173781..b596fc6a 100644 --- a/orm/models_info_m.go +++ b/orm/models_info_m.go @@ -7,6 +7,7 @@ import ( "reflect" ) +// single model info type modelInfo struct { pkg string name string @@ -20,6 +21,7 @@ type modelInfo struct { isThrough bool } +// new model info func newModelInfo(val reflect.Value) (info *modelInfo) { var ( err error @@ -79,6 +81,8 @@ func newModelInfo(val reflect.Value) (info *modelInfo) { return } +// combine related model info to new model info. +// prepare for relation models query. func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { info = new(modelInfo) info.fields = newFields() diff --git a/orm/models_utils.go b/orm/models_utils.go index 38095b7e..1466a724 100644 --- a/orm/models_utils.go +++ b/orm/models_utils.go @@ -7,10 +7,12 @@ import ( "time" ) +// get reflect.Type name with package path. func getFullName(typ reflect.Type) string { return typ.PkgPath() + "." + typ.Name() } +// get table name. method, or field name. auto snaked. func getTableName(val reflect.Value) string { ind := reflect.Indirect(val) fun := val.MethodByName("TableName") @@ -26,6 +28,7 @@ func getTableName(val reflect.Value) string { return snakeString(ind.Type().Name()) } +// get table engine, mysiam or innodb. func getTableEngine(val reflect.Value) string { fun := val.MethodByName("TableEngine") if fun.IsValid() { @@ -40,6 +43,7 @@ func getTableEngine(val reflect.Value) string { return "" } +// get table index from method. func getTableIndex(val reflect.Value) [][]string { fun := val.MethodByName("TableIndex") if fun.IsValid() { @@ -56,6 +60,7 @@ func getTableIndex(val reflect.Value) [][]string { return nil } +// get table unique from method func getTableUnique(val reflect.Value) [][]string { fun := val.MethodByName("TableUnique") if fun.IsValid() { @@ -72,6 +77,7 @@ func getTableUnique(val reflect.Value) [][]string { return nil } +// get snaked column name func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { col = strings.ToLower(col) column := col @@ -89,6 +95,7 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col return column } +// return field type as type constant from reflect.Value func getFieldType(val reflect.Value) (ft int, err error) { elm := reflect.Indirect(val) switch elm.Kind() { @@ -128,6 +135,7 @@ func getFieldType(val reflect.Value) (ft int, err error) { return } +// parse struct tag string func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) { attr := make(map[string]bool) tag := make(map[string]string) diff --git a/orm/orm.go b/orm/orm.go index 9e3c3565..71b4daa4 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -40,6 +40,7 @@ type orm struct { var _ Ormer = new(orm) +// get model info and model reflect value func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) @@ -54,6 +55,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect panic(fmt.Errorf(" table: `%s` not found, maybe not RegisterModel", name)) } +// get field info from model info by given field name func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { fi, ok := mi.fields.GetByAny(name) if !ok { @@ -62,6 +64,7 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { return fi } +// read data to model func (o *orm) Read(md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) @@ -71,6 +74,7 @@ func (o *orm) Read(md interface{}, cols ...string) error { return nil } +// insert model data to database func (o *orm) Insert(md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) @@ -83,6 +87,7 @@ func (o *orm) Insert(md interface{}) (int64, error) { return id, nil } +// set auto pk field func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { if mi.fields.pk.auto { if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { @@ -93,6 +98,7 @@ func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { } } +// insert some models to database func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { var cnt int64 @@ -127,6 +133,8 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { return cnt, nil } +// update model to database. +// cols set the columns those want to update. func (o *orm) Update(md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) @@ -136,6 +144,7 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) { return num, nil } +// delete model in database func (o *orm) Delete(md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) @@ -148,6 +157,7 @@ func (o *orm) Delete(md interface{}) (int64, error) { return num, nil } +// create a models to models queryer func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -162,6 +172,14 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { return newQueryM2M(md, o, mi, fi, ind) } +// load related models to md model. +// args are limit, offset int and order string. +// +// example: +// orm.LoadRelated(post,"Tags") +// for _,tag := range post.Tags{...} +// +// make sure the relation is defined in model struct tags. func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { _, fi, ind, qseter := o.queryRelated(md, name) @@ -223,12 +241,19 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int return nums, err } +// return a QuerySeter for related models to md model. +// it can do all, update, delete in QuerySeter. +// example: +// qs := orm.QueryRelated(post,"Tag") +// qs.All(&[]*Tag{}) +// func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { // is this api needed ? _, _, _, qs := o.queryRelated(md, name) return qs } +// get QuerySeter for related models to md model func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -260,6 +285,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, return mi, fi, ind, qs } +// get reverse relation QuerySeter func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelReverseOne, RelReverseMany: @@ -280,6 +306,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS return q } +// get relation QuerySeter func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { switch fi.fieldType { case RelOneToOne, RelForeignKey, RelManyToMany: @@ -299,6 +326,9 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { return q } +// return a QuerySeter for table operations. +// table name can be string or struct. +// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { name := "" if table, ok := ptrStructOrTableName.(string); ok { @@ -318,6 +348,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { return } +// switch to another registered database driver by given name. func (o *orm) Using(name string) error { if o.isTx { panic(fmt.Errorf(" transaction has been start, cannot change db")) @@ -335,6 +366,7 @@ func (o *orm) Using(name string) error { return nil } +// begin transaction func (o *orm) Begin() error { if o.isTx { return ErrTxHasBegan @@ -353,6 +385,7 @@ func (o *orm) Begin() error { return nil } +// commit transaction func (o *orm) Commit() error { if o.isTx == false { return ErrTxDone @@ -367,6 +400,7 @@ func (o *orm) Commit() error { return err } +// rollback transaction func (o *orm) Rollback() error { if o.isTx == false { return ErrTxDone @@ -381,14 +415,17 @@ func (o *orm) Rollback() error { return err } +// return a raw query seter for raw sql string. func (o *orm) Raw(query string, args ...interface{}) RawSeter { return newRawSet(o, query, args) } +// return current using database Driver func (o *orm) Driver() Driver { return driver(o.alias.Name) } +// create new orm func NewOrm() Ormer { BootStrap() // execute only once diff --git a/orm/orm_conds.go b/orm/orm_conds.go index 91d69986..5b1151e2 100644 --- a/orm/orm_conds.go +++ b/orm/orm_conds.go @@ -18,15 +18,19 @@ type condValue struct { isCond bool } +// condition struct. +// work for WHERE conditions. type Condition struct { params []condValue } +// return new condition struct func NewCondition() *Condition { c := &Condition{} return c } +// add expression to condition func (c Condition) And(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -35,6 +39,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition { return &c } +// add NOT expression to condition func (c Condition) AndNot(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -43,6 +48,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition { return &c } +// combine a condition to current condition func (c *Condition) AndCond(cond *Condition) *Condition { c = c.clone() if c == cond { @@ -54,6 +60,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition { return c } +// add OR expression to condition func (c Condition) Or(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -62,6 +69,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition { return &c } +// add OR NOT expression to condition func (c Condition) OrNot(expr string, args ...interface{}) *Condition { if expr == "" || len(args) == 0 { panic(fmt.Errorf(" args cannot empty")) @@ -70,6 +78,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition { return &c } +// combine a OR condition to current condition func (c *Condition) OrCond(cond *Condition) *Condition { c = c.clone() if c == cond { @@ -81,10 +90,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition { return c } +// check the condition arguments are empty or not. func (c *Condition) IsEmpty() bool { return len(c.params) == 0 } +// clone a condition func (c Condition) clone() *Condition { return &c } diff --git a/orm/orm_log.go b/orm/orm_log.go index 0bb5d6f9..e6df797a 100644 --- a/orm/orm_log.go +++ b/orm/orm_log.go @@ -13,6 +13,7 @@ type Log struct { *log.Logger } +// set io.Writer to create a Logger. func NewLog(out io.Writer) *Log { d := new(Log) d.Logger = log.New(out, "[ORM]", 1e9) @@ -40,6 +41,8 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error DebugLog.Println(con) } +// statement query logger struct. +// if dev mode, use stmtQueryLog, or use stmtQuerier. type stmtQueryLog struct { alias *alias query string @@ -84,6 +87,8 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier { return d } +// database query logger struct. +// if dev mode, use dbQueryLog, or use dbQuerier. type dbQueryLog struct { alias *alias db dbQuerier diff --git a/orm/orm_object.go b/orm/orm_object.go index 3c6d1f0e..fa644349 100644 --- a/orm/orm_object.go +++ b/orm/orm_object.go @@ -5,6 +5,7 @@ import ( "reflect" ) +// an insert queryer struct type insertSet struct { mi *modelInfo orm *orm @@ -14,6 +15,7 @@ type insertSet struct { var _ Inserter = new(insertSet) +// insert model ignore it's registered or not. func (o *insertSet) Insert(md interface{}) (int64, error) { if o.closed { return 0, ErrStmtClosed @@ -44,6 +46,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { return id, nil } +// close insert queryer statement func (o *insertSet) Close() error { if o.closed { return ErrStmtClosed @@ -52,6 +55,7 @@ func (o *insertSet) Close() error { return o.stmt.Close() } +// create new insert queryer. func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go index 6f0544d0..f0bc94b7 100644 --- a/orm/orm_querym2m.go +++ b/orm/orm_querym2m.go @@ -4,6 +4,7 @@ import ( "reflect" ) +// model to model struct type queryM2M struct { md interface{} mi *modelInfo @@ -12,6 +13,13 @@ type queryM2M struct { ind reflect.Value } +// add models to origin models when creating queryM2M. +// example: +// m2m := orm.QueryM2M(post,"Tag") +// m2m.Add(&Tag1{},&Tag2{}) +// for _,tag := range post.Tags{} +// +// make sure the relation is defined in post model struct tag. func (o *queryM2M) Add(mds ...interface{}) (int64, error) { fi := o.fi mi := fi.relThroughModelInfo @@ -67,6 +75,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { return dbase.InsertValue(orm.db, mi, true, names, values) } +// remove models following the origin model relationship func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { fi := o.fi qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) @@ -78,17 +87,20 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { return nums, nil } +// check model is existed in relationship of origin model func (o *queryM2M) Exist(md interface{}) bool { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md). Filter(fi.reverseFieldInfoTwo.name, md).Exist() } +// clean all models in related of origin model func (o *queryM2M) Clear() (int64, error) { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() } +// count all related models of origin model func (o *queryM2M) Count() (int64, error) { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() @@ -96,6 +108,7 @@ func (o *queryM2M) Count() (int64, error) { var _ QueryM2Mer = new(queryM2M) +// create new M2M queryer. func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { qm2m := new(queryM2M) qm2m.md = md diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go index b25d0542..ad8a9374 100644 --- a/orm/orm_queryset.go +++ b/orm/orm_queryset.go @@ -18,6 +18,10 @@ const ( Col_Except ) +// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: +// Params{ +// "Nums": ColValue(Col_Add, 10), +// } func ColValue(opt operator, value interface{}) interface{} { switch opt { case Col_Add, Col_Minus, Col_Multiply, Col_Except: @@ -34,6 +38,7 @@ func ColValue(opt operator, value interface{}) interface{} { return val } +// real query struct type querySet struct { mi *modelInfo cond *Condition @@ -47,6 +52,7 @@ type querySet struct { var _ QuerySeter = new(querySet) +// add condition expression to QuerySeter. func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { if o.cond == nil { o.cond = NewCondition() @@ -55,6 +61,7 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { return &o } +// add NOT condition to querySeter. func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { if o.cond == nil { o.cond = NewCondition() @@ -63,10 +70,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { return &o } +// set offset number func (o *querySet) setOffset(num interface{}) { o.offset = ToInt64(num) } +// add LIMIT value. +// args[0] means offset, e.g. LIMIT num,offset. func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { o.limit = ToInt64(limit) if len(args) > 0 { @@ -75,16 +85,21 @@ func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { return &o } +// add OFFSET value func (o querySet) Offset(offset interface{}) QuerySeter { o.setOffset(offset) return &o } +// add ORDER expression. +// "column" means ASC, "-column" means DESC. func (o querySet) OrderBy(exprs ...string) QuerySeter { o.orders = exprs return &o } +// set relation model to query together. +// it will query relation models and assign to parent model. func (o querySet) RelatedSel(params ...interface{}) QuerySeter { var related []string if len(params) == 0 { @@ -105,36 +120,50 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter { return &o } +// set condition to QuerySeter. func (o querySet) SetCond(cond *Condition) QuerySeter { o.cond = cond return &o } +// return QuerySeter execution result number func (o *querySet) Count() (int64, error) { return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } +// check result empty or not after QuerySeter executed func (o *querySet) Exist() bool { cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return cnt > 0 } +// execute update with parameters func (o *querySet) Update(values Params) (int64, error) { return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) } +// execute delete func (o *querySet) Delete() (int64, error) { return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } +// return a insert queryer. +// it can be used in times. +// example: +// i,err := sq.PrepareInsert() +// i.Add(&user1{},&user2{}) func (o *querySet) PrepareInsert() (Inserter, error) { return newInsertSet(o.orm, o.mi) } +// query all data and map to containers. +// cols means the columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) } +// query one row data and map to containers. +// cols means the columns when querying. func (o *querySet) One(container interface{}, cols ...string) error { num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) if err != nil { @@ -149,18 +178,26 @@ func (o *querySet) One(container interface{}, cols ...string) error { return nil } +// query all data and map to []map[string]interface. +// expres means condition expression. +// it converts data to []map[column]value. func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } +// query all data and map to [][]interface +// it converts data to [][column_index]value func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } +// query all data and map to []interface. +// it's designed for one row record set, auto change to []value, not [][column]value. func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) } +// create new QuerySeter. func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { o := new(querySet) o.mi = mi diff --git a/orm/orm_raw.go b/orm/orm_raw.go index a713dbac..3f5fb162 100644 --- a/orm/orm_raw.go +++ b/orm/orm_raw.go @@ -7,6 +7,7 @@ import ( "time" ) +// raw sql string prepared statement type rawPrepare struct { rs *rawSet stmt stmtQuerier @@ -44,6 +45,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) { return o, nil } +// raw query seter type rawSet struct { query string args []interface{} @@ -52,11 +54,13 @@ type rawSet struct { var _ RawSeter = new(rawSet) +// set args for every query func (o rawSet) SetArgs(args ...interface{}) RawSeter { o.args = args return &o } +// execute raw sql and return sql.Result func (o *rawSet) Exec() (sql.Result, error) { query := o.query o.orm.alias.DbBaser.ReplaceMarks(&query) @@ -65,6 +69,7 @@ func (o *rawSet) Exec() (sql.Result, error) { return o.orm.db.Exec(query, args...) } +// set field value to row container func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { switch ind.Kind() { case reflect.Bool: @@ -163,6 +168,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { } } +// set field value in loop for slice container func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { nInds := *nIndsPtr @@ -233,6 +239,7 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr } } +// query data and map to container func (o *rawSet) QueryRow(containers ...interface{}) error { refs := make([]interface{}, 0, len(containers)) sInds := make([]reflect.Value, 0) @@ -362,6 +369,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { return nil } +// query data rows and map to container func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { refs := make([]interface{}, 0, len(containers)) sInds := make([]reflect.Value, 0) @@ -615,18 +623,22 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { return cnt, nil } +// query data to []map[string]interface func (o *rawSet) Values(container *[]Params) (int64, error) { return o.readValues(container) } +// query data to [][]interface func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) { return o.readValues(container) } +// query data to []interface func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) { return o.readValues(container) } +// return prepared raw statement for used in times. func (o *rawSet) Prepare() (RawPreparer, error) { return newRawPreparer(o) } diff --git a/orm/types.go b/orm/types.go index a6487fc0..6f13ed67 100644 --- a/orm/types.go +++ b/orm/types.go @@ -6,11 +6,13 @@ import ( "time" ) +// database driver type Driver interface { Name() string Type() DriverType } +// field info type Fielder interface { String() string FieldType() int @@ -18,6 +20,7 @@ type Fielder interface { RawValue() interface{} } +// orm struct type Ormer interface { Read(interface{}, ...string) error Insert(interface{}) (int64, error) @@ -35,11 +38,13 @@ type Ormer interface { Driver() Driver } +// insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) Close() error } +// query seter type QuerySeter interface { Filter(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter @@ -60,6 +65,7 @@ type QuerySeter interface { ValuesFlat(*ParamsList, string) (int64, error) } +// model to model query struct type QueryM2Mer interface { Add(...interface{}) (int64, error) Remove(...interface{}) (int64, error) @@ -68,11 +74,13 @@ type QueryM2Mer interface { Count() (int64, error) } +// raw query statement type RawPreparer interface { Exec(...interface{}) (sql.Result, error) Close() error } +// raw query seter type RawSeter interface { Exec() (sql.Result, error) QueryRow(...interface{}) error @@ -84,6 +92,7 @@ type RawSeter interface { Prepare() (RawPreparer, error) } +// statement querier type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) @@ -91,6 +100,7 @@ type stmtQuerier interface { QueryRow(args ...interface{}) *sql.Row } +// db querier type dbQuerier interface { Prepare(query string) (*sql.Stmt, error) Exec(query string, args ...interface{}) (sql.Result, error) @@ -98,15 +108,18 @@ type dbQuerier interface { QueryRow(query string, args ...interface{}) *sql.Row } +// transaction beginner type txer interface { Begin() (*sql.Tx, error) } +// transaction ending type txEnder interface { Commit() error Rollback() error } +// base database struct type dbBaser interface { Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) diff --git a/orm/utils.go b/orm/utils.go index 237b3edf..2e347278 100644 --- a/orm/utils.go +++ b/orm/utils.go @@ -10,6 +10,7 @@ import ( type StrTo string +// set string func (f *StrTo) Set(v string) { if v != "" { *f = StrTo(v) @@ -18,77 +19,93 @@ func (f *StrTo) Set(v string) { } } +// clean string func (f *StrTo) Clear() { *f = StrTo(0x1E) } +// check string exist func (f StrTo) Exist() bool { return string(f) != string(0x1E) } +// string to bool func (f StrTo) Bool() (bool, error) { return strconv.ParseBool(f.String()) } +// string to float32 func (f StrTo) Float32() (float32, error) { v, err := strconv.ParseFloat(f.String(), 32) return float32(v), err } +// string to float64 func (f StrTo) Float64() (float64, error) { return strconv.ParseFloat(f.String(), 64) } +// string to int func (f StrTo) Int() (int, error) { v, err := strconv.ParseInt(f.String(), 10, 32) return int(v), err } +// string to int8 func (f StrTo) Int8() (int8, error) { v, err := strconv.ParseInt(f.String(), 10, 8) return int8(v), err } +// string to int16 func (f StrTo) Int16() (int16, error) { v, err := strconv.ParseInt(f.String(), 10, 16) return int16(v), err } +// string to int32 func (f StrTo) Int32() (int32, error) { v, err := strconv.ParseInt(f.String(), 10, 32) return int32(v), err } +// string to int64 func (f StrTo) Int64() (int64, error) { v, err := strconv.ParseInt(f.String(), 10, 64) return int64(v), err } +// string to uint func (f StrTo) Uint() (uint, error) { v, err := strconv.ParseUint(f.String(), 10, 32) return uint(v), err } +// string to uint8 func (f StrTo) Uint8() (uint8, error) { v, err := strconv.ParseUint(f.String(), 10, 8) return uint8(v), err } +// string to uint16 func (f StrTo) Uint16() (uint16, error) { v, err := strconv.ParseUint(f.String(), 10, 16) return uint16(v), err } +// string to uint31 func (f StrTo) Uint32() (uint32, error) { v, err := strconv.ParseUint(f.String(), 10, 32) return uint32(v), err } +// string to uint64 func (f StrTo) Uint64() (uint64, error) { v, err := strconv.ParseUint(f.String(), 10, 64) return uint64(v), err } +// string to string func (f StrTo) String() string { if f.Exist() { return string(f) @@ -96,6 +113,7 @@ func (f StrTo) String() string { return "" } +// interface to string func ToStr(value interface{}, args ...int) (s string) { switch v := value.(type) { case bool: @@ -134,6 +152,7 @@ func ToStr(value interface{}, args ...int) (s string) { return s } +// interface to int64 func ToInt64(value interface{}) (d int64) { val := reflect.ValueOf(value) switch value.(type) { @@ -147,6 +166,7 @@ func ToInt64(value interface{}) (d int64) { return } +// snake string, XxYy to xx_yy func snakeString(s string) string { data := make([]byte, 0, len(s)*2) j := false @@ -164,6 +184,7 @@ func snakeString(s string) string { return strings.ToLower(string(data[:len(data)])) } +// camel string, xx_yy to XxYy func camelString(s string) string { data := make([]byte, 0, len(s)) j := false @@ -190,6 +211,7 @@ func camelString(s string) string { type argString []string +// get string by index from string slice func (a argString) Get(i int, args ...string) (r string) { if i >= 0 && i < len(a) { r = a[i] @@ -201,6 +223,7 @@ func (a argString) Get(i int, args ...string) (r string) { type argInt []int +// get int by index from int slice func (a argInt) Get(i int, args ...int) (r int) { if i >= 0 && i < len(a) { r = a[i] @@ -213,6 +236,7 @@ func (a argInt) Get(i int, args ...int) (r int) { type argAny []interface{} +// get interface by index from interface slice func (a argAny) Get(i int, args ...interface{}) (r interface{}) { if i >= 0 && i < len(a) { r = a[i] @@ -223,15 +247,18 @@ func (a argAny) Get(i int, args ...interface{}) (r interface{}) { return } +// parse time to string with location func timeParse(dateString, format string) (time.Time, error) { tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) return tp, err } +// format time string func timeFormat(t time.Time, format string) string { return t.Format(format) } +// get pointer indirect type func indirectType(v reflect.Type) reflect.Type { switch v.Kind() { case reflect.Ptr: