diff --git a/adminui.go b/adminui.go
index 7bb32b34..cdcdef33 100644
--- a/adminui.go
+++ b/adminui.go
@@ -78,13 +78,14 @@ var qpsTpl = `{{define "content"}}
{{range $i, $elem := .Content.Data}}
- {{range $elem}}
-
- {{.}}
- |
- {{end}}
+ {{index $elem 0}} |
+ {{index $elem 1}} |
+ {{index $elem 2}} |
+ {{index $elem 4}} |
+ {{index $elem 6}} |
+ {{index $elem 8}} |
+ {{index $elem 10}} |
-
{{end}}
diff --git a/beego.go b/beego.go
index 32b64f75..10304994 100644
--- a/beego.go
+++ b/beego.go
@@ -23,7 +23,7 @@ import (
const (
// VERSION represent beego web framework version.
- VERSION = "1.6.1"
+ VERSION = "1.7.0"
// DEV is for develop
DEV = "dev"
diff --git a/cache/memcache/memcache.go b/cache/memcache/memcache.go
index 3f0fe411..972361f7 100644
--- a/cache/memcache/memcache.go
+++ b/cache/memcache/memcache.go
@@ -33,12 +33,10 @@ import (
"encoding/json"
"errors"
"strings"
-
- "github.com/bradfitz/gomemcache/memcache"
-
"time"
"github.com/astaxie/beego/cache"
+ "github.com/bradfitz/gomemcache/memcache"
)
// Cache Memcache adapter.
@@ -60,7 +58,7 @@ func (rc *Cache) Get(key string) interface{} {
}
}
if item, err := rc.conn.Get(key); err == nil {
- return string(item.Value)
+ return item.Value
}
return nil
}
@@ -80,7 +78,7 @@ func (rc *Cache) GetMulti(keys []string) []interface{} {
mv, err := rc.conn.GetMulti(keys)
if err == nil {
for _, v := range mv {
- rv = append(rv, string(v.Value))
+ rv = append(rv, v.Value)
}
return rv
}
@@ -90,18 +88,21 @@ func (rc *Cache) GetMulti(keys []string) []interface{} {
return rv
}
-// Put put value to memcache. only support string.
+// Put put value to memcache.
func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error {
if rc.conn == nil {
if err := rc.connectInit(); err != nil {
return err
}
}
- v, ok := val.(string)
- if !ok {
- return errors.New("val must string")
+ item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)}
+ if v, ok := val.([]byte); ok {
+ item.Value = v
+ } else if str, ok := val.(string); ok {
+ item.Value = []byte(str)
+ } else {
+ return errors.New("val only support string and []byte")
}
- item := memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout / time.Second)}
return rc.conn.Set(&item)
}
diff --git a/cache/memcache/memcache_test.go b/cache/memcache/memcache_test.go
index 0c8c57f2..d9129b69 100644
--- a/cache/memcache/memcache_test.go
+++ b/cache/memcache/memcache_test.go
@@ -46,7 +46,7 @@ func TestMemcacheCache(t *testing.T) {
t.Error("set Error", err)
}
- if v, err := strconv.Atoi(bm.Get("astaxie").(string)); err != nil || v != 1 {
+ if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 {
t.Error("get err")
}
@@ -54,7 +54,7 @@ func TestMemcacheCache(t *testing.T) {
t.Error("Incr Error", err)
}
- if v, err := strconv.Atoi(bm.Get("astaxie").(string)); err != nil || v != 2 {
+ if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 {
t.Error("get err")
}
@@ -62,7 +62,7 @@ func TestMemcacheCache(t *testing.T) {
t.Error("Decr Error", err)
}
- if v, err := strconv.Atoi(bm.Get("astaxie").(string)); err != nil || v != 1 {
+ if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 {
t.Error("get err")
}
bm.Delete("astaxie")
@@ -78,7 +78,7 @@ func TestMemcacheCache(t *testing.T) {
t.Error("check err")
}
- if v := bm.Get("astaxie").(string); v != "author" {
+ if v := bm.Get("astaxie").([]byte); string(v) != "author" {
t.Error("get err")
}
@@ -94,10 +94,10 @@ func TestMemcacheCache(t *testing.T) {
if len(vv) != 2 {
t.Error("GetMulti ERROR")
}
- if vv[0].(string) != "author" && vv[0].(string) != "author1" {
+ if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" {
t.Error("GetMulti ERROR")
}
- if vv[1].(string) != "author1" && vv[1].(string) != "author" {
+ if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" {
t.Error("GetMulti ERROR")
}
diff --git a/config.go b/config.go
index ead538b0..a4f40611 100644
--- a/config.go
+++ b/config.go
@@ -19,9 +19,11 @@ import (
"os"
"path/filepath"
"reflect"
+ "runtime"
"strings"
"github.com/astaxie/beego/config"
+ "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
@@ -34,6 +36,7 @@ type Config struct {
RouterCaseSensitive bool
ServerName string
RecoverPanic bool
+ RecoverFunc func(*context.Context)
CopyRequestBody bool
EnableGzip bool
MaxMemory int64
@@ -142,6 +145,37 @@ func init() {
}
}
+func recoverPanic(ctx *context.Context) {
+ if err := recover(); err != nil {
+ if err == ErrAbort {
+ return
+ }
+ if !BConfig.RecoverPanic {
+ panic(err)
+ }
+ if BConfig.EnableErrorsShow {
+ if _, ok := ErrorMaps[fmt.Sprint(err)]; ok {
+ exception(fmt.Sprint(err), ctx)
+ return
+ }
+ }
+ var stack string
+ logs.Critical("the request url is ", ctx.Input.URL())
+ logs.Critical("Handler crashed with error", err)
+ for i := 1; ; i++ {
+ _, file, line, ok := runtime.Caller(i)
+ if !ok {
+ break
+ }
+ logs.Critical(fmt.Sprintf("%s:%d", file, line))
+ stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
+ }
+ if BConfig.RunMode == DEV {
+ showErr(err, ctx, stack)
+ }
+ }
+}
+
func newBConfig() *Config {
return &Config{
AppName: "beego",
@@ -149,6 +183,7 @@ func newBConfig() *Config {
RouterCaseSensitive: true,
ServerName: "beegoServer:" + VERSION,
RecoverPanic: true,
+ RecoverFunc: recoverPanic,
CopyRequestBody: false,
EnableGzip: false,
MaxMemory: 1 << 26, //64MB
@@ -233,9 +268,9 @@ func assignConfig(ac config.Configer) error {
sds := strings.Fields(sd)
for _, v := range sds {
if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 {
- BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[1]
+ BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[1]
} else {
- BConfig.WebConfig.StaticDir["/"+strings.TrimRight(url2fsmap[0], "/")] = url2fsmap[0]
+ BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[0]
}
}
}
diff --git a/config/yaml/yaml.go b/config/yaml/yaml.go
index 64e25cb3..e3260215 100644
--- a/config/yaml/yaml.go
+++ b/config/yaml/yaml.go
@@ -281,7 +281,7 @@ func (c *ConfigContainer) DIY(key string) (v interface{}, err error) {
func (c *ConfigContainer) getData(key string) (interface{}, error) {
if len(key) == 0 {
- return nil, errors.New("key is emtpy")
+ return nil, errors.New("key is empty")
}
if v, ok := c.data[key]; ok {
diff --git a/context/acceptencoder.go b/context/acceptencoder.go
index cb735445..350b560d 100644
--- a/context/acceptencoder.go
+++ b/context/acceptencoder.go
@@ -209,9 +209,13 @@ func parseEncoding(r *http.Request) string {
continue
}
vs := strings.Split(v, ";")
+ var cf acceptEncoder
+ var ok bool
+ if cf, ok = encoderMap[vs[0]]; !ok {
+ continue
+ }
if len(vs) == 1 {
- lastQ = q{vs[0], 1}
- break
+ return cf.name
}
if len(vs) == 2 {
f, _ := strconv.ParseFloat(strings.Replace(vs[1], "q=", "", -1), 64)
@@ -219,12 +223,9 @@ func parseEncoding(r *http.Request) string {
continue
}
if f > lastQ.value {
- lastQ = q{vs[0], f}
+ lastQ = q{cf.name, f}
}
}
}
- if cf, ok := encoderMap[lastQ.name]; ok {
- return cf.name
- }
- return ""
+ return lastQ.name
}
diff --git a/context/acceptencoder_test.go b/context/acceptencoder_test.go
index 3afff679..e3d61e27 100644
--- a/context/acceptencoder_test.go
+++ b/context/acceptencoder_test.go
@@ -41,4 +41,19 @@ func Test_ExtractEncoding(t *testing.T) {
if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"*"}}}) != "gzip" {
t.Fail()
}
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x,gzip,deflate"}}}) != "gzip" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip,x,deflate"}}}) != "gzip" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x,deflate"}}}) != "deflate" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"x"}}}) != "" {
+ t.Fail()
+ }
+ if parseEncoding(&http.Request{Header: map[string][]string{"Accept-Encoding": {"gzip;q=0.5,x;q=0.8"}}}) != "gzip" {
+ t.Fail()
+ }
}
diff --git a/context/input.go b/context/input.go
index edfdf530..1e6eaf71 100644
--- a/context/input.go
+++ b/context/input.go
@@ -40,12 +40,14 @@ var (
// BeegoInput operates the http request header, data, cookie and body.
// it also contains router params and current session.
type BeegoInput struct {
- Context *Context
- CruSession session.Store
- pnames []string
- pvalues []string
- data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
- RequestBody []byte
+ Context *Context
+ CruSession session.Store
+ pnames []string
+ pvalues []string
+ data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
+ RequestBody []byte
+ RunMethod string
+ RunController reflect.Type
}
// NewInput return BeegoInput generated by Context.
@@ -89,6 +91,9 @@ func (input *BeegoInput) Site() string {
// Scheme returns request scheme as "http" or "https".
func (input *BeegoInput) Scheme() string {
+ if scheme := input.Header("X-Forwarded-Proto"); scheme != "" {
+ return scheme
+ }
if input.Context.Request.URL.Scheme != "" {
return input.Context.Request.URL.Scheme
}
@@ -298,6 +303,14 @@ func (input *BeegoInput) SetParam(key, val string) {
input.pnames = append(input.pnames, key)
}
+// ResetParams clears any of the input's Params
+// This function is used to clear parameters so they may be reset between filter
+// passes.
+func (input *BeegoInput) ResetParams() {
+ input.pnames = input.pnames[:0]
+ input.pvalues = input.pvalues[:0]
+}
+
// Query returns input data item string by a given string.
func (input *BeegoInput) Query(key string) string {
if val := input.Param(key); val != "" {
@@ -326,13 +339,16 @@ func (input *BeegoInput) Cookie(key string) string {
}
// Session returns current session item value by a given key.
-// if non-existed, return empty string.
+// if non-existed, return nil.
func (input *BeegoInput) Session(key interface{}) interface{} {
return input.CruSession.Get(key)
}
// CopyBody returns the raw request body data as bytes.
func (input *BeegoInput) CopyBody(MaxMemory int64) []byte {
+ if input.Context.Request.Body == nil {
+ return []byte{}
+ }
safe := &io.LimitedReader{R: input.Context.Request.Body, N: MaxMemory}
requestbody, _ := ioutil.ReadAll(safe)
input.Context.Request.Body.Close()
@@ -576,12 +592,15 @@ func (input *BeegoInput) bindStruct(params *url.Values, key string, typ reflect.
result := reflect.New(typ).Elem()
fieldValues := make(map[string]reflect.Value)
for reqKey, val := range *params {
- if !strings.HasPrefix(reqKey, key+".") {
+ var fieldName string
+ if strings.HasPrefix(reqKey, key+".") {
+ fieldName = reqKey[len(key)+1:]
+ } else if strings.HasPrefix(reqKey, key+"[") && reqKey[len(reqKey)-1] == ']' {
+ fieldName = reqKey[len(key)+1 : len(reqKey)-1]
+ } else {
continue
}
- fieldName := reqKey[len(key)+1:]
-
if _, ok := fieldValues[fieldName]; !ok {
// Time to bind this field. Get it and make sure we can set it.
fieldValue := result.FieldByName(fieldName)
diff --git a/context/input_test.go b/context/input_test.go
index 8887aec4..e64addba 100644
--- a/context/input_test.go
+++ b/context/input_test.go
@@ -75,6 +75,24 @@ func TestParse(t *testing.T) {
fmt.Println(user)
}
+func TestParse2(t *testing.T) {
+ r, _ := http.NewRequest("GET", "/?user[0][Username]=Raph&user[1].Username=Leo&user[0].Password=123456&user[1][Password]=654321", nil)
+ beegoInput := NewInput()
+ beegoInput.Context = NewContext()
+ beegoInput.Context.Reset(httptest.NewRecorder(), r)
+ beegoInput.ParseFormOrMulitForm(1 << 20)
+ type User struct {
+ Username string
+ Password string
+ }
+ var users []User
+ err := beegoInput.Bind(&users, "user")
+ fmt.Println(users)
+ if err != nil || users[0].Username != "Raph" || users[0].Password != "123456" || users[1].Username != "Leo" || users[1].Password != "654321" {
+ t.Fatal("users info wrong")
+ }
+}
+
func TestSubDomain(t *testing.T) {
r, _ := http.NewRequest("GET", "http://www.example.com/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil)
beegoInput := NewInput()
diff --git a/context/output.go b/context/output.go
index e1ad23e0..c09b9d19 100644
--- a/context/output.go
+++ b/context/output.go
@@ -146,18 +146,12 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface
}
// default false. for session cookie default true
- httponly := false
if len(others) > 4 {
if v, ok := others[4].(bool); ok && v {
- // HttpOnly = true
- httponly = true
+ fmt.Fprintf(&b, "; HttpOnly")
}
}
- if httponly {
- fmt.Fprintf(&b, "; HttpOnly")
- }
-
output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String())
}
@@ -212,7 +206,8 @@ func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error {
if callback == "" {
return errors.New(`"callback" parameter required`)
}
- callbackContent := bytes.NewBufferString(" " + template.JSEscapeString(callback))
+ callback = template.JSEscapeString(callback)
+ callbackContent := bytes.NewBufferString(" if(window." + callback + ")" + callback)
callbackContent.WriteString("(")
callbackContent.Write(content)
callbackContent.WriteString(");\r\n")
diff --git a/controller.go b/controller.go
index 3a9d1618..c7eb118d 100644
--- a/controller.go
+++ b/controller.go
@@ -71,6 +71,7 @@ type Controller struct {
TplName string
Layout string
LayoutSections map[string]string // the key is the section name and the value is the template name
+ TplPrefix string
TplExt string
EnableRender bool
@@ -227,6 +228,9 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) {
if c.TplName == "" {
c.TplName = strings.ToLower(c.controllerName) + "/" + strings.ToLower(c.actionName) + "." + c.TplExt
}
+ if c.TplPrefix != "" {
+ c.TplName = c.TplPrefix + c.TplName
+ }
if BConfig.RunMode == DEV {
buildFiles := []string{c.TplName}
if c.Layout != "" {
diff --git a/error.go b/error.go
index ce25d281..ab626247 100644
--- a/error.go
+++ b/error.go
@@ -93,7 +93,11 @@ func showErr(err interface{}, ctx *context.Context, stack string) {
"BeegoVersion": VERSION,
"GoVersion": runtime.Version(),
}
- ctx.ResponseWriter.WriteHeader(500)
+ if ctx.Output.Status != 0 {
+ ctx.ResponseWriter.WriteHeader(ctx.Output.Status)
+ } else {
+ ctx.ResponseWriter.WriteHeader(500)
+ }
t.Execute(ctx.ResponseWriter, data)
}
diff --git a/filter.go b/filter.go
index 863223f7..9cc6e913 100644
--- a/filter.go
+++ b/filter.go
@@ -27,6 +27,7 @@ type FilterRouter struct {
tree *Tree
pattern string
returnOnOutput bool
+ resetParams bool
}
// ValidRouter checks if the current request is matched by this filter.
diff --git a/hooks.go b/hooks.go
index 3dca1b8d..0c7d05fe 100644
--- a/hooks.go
+++ b/hooks.go
@@ -45,26 +45,24 @@ func registerSession() error {
if BConfig.WebConfig.Session.SessionOn {
var err error
sessionConfig := AppConfig.String("sessionConfig")
+ conf := new(session.ManagerConfig)
if sessionConfig == "" {
- conf := map[string]interface{}{
- "cookieName": BConfig.WebConfig.Session.SessionName,
- "gclifetime": BConfig.WebConfig.Session.SessionGCMaxLifetime,
- "providerConfig": filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig),
- "secure": BConfig.Listen.EnableHTTPS,
- "enableSetCookie": BConfig.WebConfig.Session.SessionAutoSetCookie,
- "domain": BConfig.WebConfig.Session.SessionDomain,
- "cookieLifeTime": BConfig.WebConfig.Session.SessionCookieLifeTime,
- "enableSidInHttpHeader": BConfig.WebConfig.Session.EnableSidInHttpHeader,
- "sessionNameInHttpHeader": BConfig.WebConfig.Session.SessionNameInHttpHeader,
- "enableSidInUrlQuery": BConfig.WebConfig.Session.EnableSidInUrlQuery,
- }
- confBytes, err := json.Marshal(conf)
- if err != nil {
+ conf.CookieName = BConfig.WebConfig.Session.SessionName
+ conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie
+ conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime
+ conf.Secure = BConfig.Listen.EnableHTTPS
+ conf.CookieLifeTime = BConfig.WebConfig.Session.SessionCookieLifeTime
+ conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig)
+ conf.Domain = BConfig.WebConfig.Session.SessionDomain
+ conf.EnableSidInHttpHeader = BConfig.WebConfig.Session.EnableSidInHttpHeader
+ conf.SessionNameInHttpHeader = BConfig.WebConfig.Session.SessionNameInHttpHeader
+ conf.EnableSidInUrlQuery = BConfig.WebConfig.Session.EnableSidInUrlQuery
+ } else {
+ if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil {
return err
}
- sessionConfig = string(confBytes)
}
- if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, sessionConfig); err != nil {
+ if GlobalSessions, err = session.NewManager(BConfig.WebConfig.Session.SessionProvider, conf); err != nil {
return err
}
go GlobalSessions.GC()
diff --git a/httplib/httplib.go b/httplib/httplib.go
index 76984122..7e6f2700 100644
--- a/httplib/httplib.go
+++ b/httplib/httplib.go
@@ -409,9 +409,10 @@ func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
if trans == nil {
// create default transport
trans = &http.Transport{
- TLSClientConfig: b.setting.TLSClientConfig,
- Proxy: b.setting.Proxy,
- Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
+ TLSClientConfig: b.setting.TLSClientConfig,
+ Proxy: b.setting.Proxy,
+ Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout),
+ MaxIdleConnsPerHost: -1,
}
} else {
// if b.transport is *http.Transport then set the settings.
diff --git a/logs/color.go b/logs/color.go
new file mode 100644
index 00000000..41d23638
--- /dev/null
+++ b/logs/color.go
@@ -0,0 +1,28 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !windows
+
+package logs
+
+import "io"
+
+type ansiColorWriter struct {
+ w io.Writer
+ mode outputMode
+}
+
+func (cw *ansiColorWriter) Write(p []byte) (int, error) {
+ return cw.w.Write(p)
+}
diff --git a/logs/color_windows.go b/logs/color_windows.go
new file mode 100644
index 00000000..deee4c87
--- /dev/null
+++ b/logs/color_windows.go
@@ -0,0 +1,428 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build windows
+
+package logs
+
+import (
+ "bytes"
+ "io"
+ "strings"
+ "syscall"
+ "unsafe"
+)
+
+type (
+ csiState int
+ parseResult int
+)
+
+const (
+ outsideCsiCode csiState = iota
+ firstCsiCode
+ secondCsiCode
+)
+
+const (
+ noConsole parseResult = iota
+ changedColor
+ unknown
+)
+
+type ansiColorWriter struct {
+ w io.Writer
+ mode outputMode
+ state csiState
+ paramStartBuf bytes.Buffer
+ paramBuf bytes.Buffer
+}
+
+const (
+ firstCsiChar byte = '\x1b'
+ secondeCsiChar byte = '['
+ separatorChar byte = ';'
+ sgrCode byte = 'm'
+)
+
+const (
+ foregroundBlue = uint16(0x0001)
+ foregroundGreen = uint16(0x0002)
+ foregroundRed = uint16(0x0004)
+ foregroundIntensity = uint16(0x0008)
+ backgroundBlue = uint16(0x0010)
+ backgroundGreen = uint16(0x0020)
+ backgroundRed = uint16(0x0040)
+ backgroundIntensity = uint16(0x0080)
+ underscore = uint16(0x8000)
+
+ foregroundMask = foregroundBlue | foregroundGreen | foregroundRed | foregroundIntensity
+ backgroundMask = backgroundBlue | backgroundGreen | backgroundRed | backgroundIntensity
+)
+
+const (
+ ansiReset = "0"
+ ansiIntensityOn = "1"
+ ansiIntensityOff = "21"
+ ansiUnderlineOn = "4"
+ ansiUnderlineOff = "24"
+ ansiBlinkOn = "5"
+ ansiBlinkOff = "25"
+
+ ansiForegroundBlack = "30"
+ ansiForegroundRed = "31"
+ ansiForegroundGreen = "32"
+ ansiForegroundYellow = "33"
+ ansiForegroundBlue = "34"
+ ansiForegroundMagenta = "35"
+ ansiForegroundCyan = "36"
+ ansiForegroundWhite = "37"
+ ansiForegroundDefault = "39"
+
+ ansiBackgroundBlack = "40"
+ ansiBackgroundRed = "41"
+ ansiBackgroundGreen = "42"
+ ansiBackgroundYellow = "43"
+ ansiBackgroundBlue = "44"
+ ansiBackgroundMagenta = "45"
+ ansiBackgroundCyan = "46"
+ ansiBackgroundWhite = "47"
+ ansiBackgroundDefault = "49"
+
+ ansiLightForegroundGray = "90"
+ ansiLightForegroundRed = "91"
+ ansiLightForegroundGreen = "92"
+ ansiLightForegroundYellow = "93"
+ ansiLightForegroundBlue = "94"
+ ansiLightForegroundMagenta = "95"
+ ansiLightForegroundCyan = "96"
+ ansiLightForegroundWhite = "97"
+
+ ansiLightBackgroundGray = "100"
+ ansiLightBackgroundRed = "101"
+ ansiLightBackgroundGreen = "102"
+ ansiLightBackgroundYellow = "103"
+ ansiLightBackgroundBlue = "104"
+ ansiLightBackgroundMagenta = "105"
+ ansiLightBackgroundCyan = "106"
+ ansiLightBackgroundWhite = "107"
+)
+
+type drawType int
+
+const (
+ foreground drawType = iota
+ background
+)
+
+type winColor struct {
+ code uint16
+ drawType drawType
+}
+
+var colorMap = map[string]winColor{
+ ansiForegroundBlack: {0, foreground},
+ ansiForegroundRed: {foregroundRed, foreground},
+ ansiForegroundGreen: {foregroundGreen, foreground},
+ ansiForegroundYellow: {foregroundRed | foregroundGreen, foreground},
+ ansiForegroundBlue: {foregroundBlue, foreground},
+ ansiForegroundMagenta: {foregroundRed | foregroundBlue, foreground},
+ ansiForegroundCyan: {foregroundGreen | foregroundBlue, foreground},
+ ansiForegroundWhite: {foregroundRed | foregroundGreen | foregroundBlue, foreground},
+ ansiForegroundDefault: {foregroundRed | foregroundGreen | foregroundBlue, foreground},
+
+ ansiBackgroundBlack: {0, background},
+ ansiBackgroundRed: {backgroundRed, background},
+ ansiBackgroundGreen: {backgroundGreen, background},
+ ansiBackgroundYellow: {backgroundRed | backgroundGreen, background},
+ ansiBackgroundBlue: {backgroundBlue, background},
+ ansiBackgroundMagenta: {backgroundRed | backgroundBlue, background},
+ ansiBackgroundCyan: {backgroundGreen | backgroundBlue, background},
+ ansiBackgroundWhite: {backgroundRed | backgroundGreen | backgroundBlue, background},
+ ansiBackgroundDefault: {0, background},
+
+ ansiLightForegroundGray: {foregroundIntensity, foreground},
+ ansiLightForegroundRed: {foregroundIntensity | foregroundRed, foreground},
+ ansiLightForegroundGreen: {foregroundIntensity | foregroundGreen, foreground},
+ ansiLightForegroundYellow: {foregroundIntensity | foregroundRed | foregroundGreen, foreground},
+ ansiLightForegroundBlue: {foregroundIntensity | foregroundBlue, foreground},
+ ansiLightForegroundMagenta: {foregroundIntensity | foregroundRed | foregroundBlue, foreground},
+ ansiLightForegroundCyan: {foregroundIntensity | foregroundGreen | foregroundBlue, foreground},
+ ansiLightForegroundWhite: {foregroundIntensity | foregroundRed | foregroundGreen | foregroundBlue, foreground},
+
+ ansiLightBackgroundGray: {backgroundIntensity, background},
+ ansiLightBackgroundRed: {backgroundIntensity | backgroundRed, background},
+ ansiLightBackgroundGreen: {backgroundIntensity | backgroundGreen, background},
+ ansiLightBackgroundYellow: {backgroundIntensity | backgroundRed | backgroundGreen, background},
+ ansiLightBackgroundBlue: {backgroundIntensity | backgroundBlue, background},
+ ansiLightBackgroundMagenta: {backgroundIntensity | backgroundRed | backgroundBlue, background},
+ ansiLightBackgroundCyan: {backgroundIntensity | backgroundGreen | backgroundBlue, background},
+ ansiLightBackgroundWhite: {backgroundIntensity | backgroundRed | backgroundGreen | backgroundBlue, background},
+}
+
+var (
+ kernel32 = syscall.NewLazyDLL("kernel32.dll")
+ procSetConsoleTextAttribute = kernel32.NewProc("SetConsoleTextAttribute")
+ procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
+ defaultAttr *textAttributes
+)
+
+func init() {
+ screenInfo := getConsoleScreenBufferInfo(uintptr(syscall.Stdout))
+ if screenInfo != nil {
+ colorMap[ansiForegroundDefault] = winColor{
+ screenInfo.WAttributes & (foregroundRed | foregroundGreen | foregroundBlue),
+ foreground,
+ }
+ colorMap[ansiBackgroundDefault] = winColor{
+ screenInfo.WAttributes & (backgroundRed | backgroundGreen | backgroundBlue),
+ background,
+ }
+ defaultAttr = convertTextAttr(screenInfo.WAttributes)
+ }
+}
+
+type coord struct {
+ X, Y int16
+}
+
+type smallRect struct {
+ Left, Top, Right, Bottom int16
+}
+
+type consoleScreenBufferInfo struct {
+ DwSize coord
+ DwCursorPosition coord
+ WAttributes uint16
+ SrWindow smallRect
+ DwMaximumWindowSize coord
+}
+
+func getConsoleScreenBufferInfo(hConsoleOutput uintptr) *consoleScreenBufferInfo {
+ var csbi consoleScreenBufferInfo
+ ret, _, _ := procGetConsoleScreenBufferInfo.Call(
+ hConsoleOutput,
+ uintptr(unsafe.Pointer(&csbi)))
+ if ret == 0 {
+ return nil
+ }
+ return &csbi
+}
+
+func setConsoleTextAttribute(hConsoleOutput uintptr, wAttributes uint16) bool {
+ ret, _, _ := procSetConsoleTextAttribute.Call(
+ hConsoleOutput,
+ uintptr(wAttributes))
+ return ret != 0
+}
+
+type textAttributes struct {
+ foregroundColor uint16
+ backgroundColor uint16
+ foregroundIntensity uint16
+ backgroundIntensity uint16
+ underscore uint16
+ otherAttributes uint16
+}
+
+func convertTextAttr(winAttr uint16) *textAttributes {
+ fgColor := winAttr & (foregroundRed | foregroundGreen | foregroundBlue)
+ bgColor := winAttr & (backgroundRed | backgroundGreen | backgroundBlue)
+ fgIntensity := winAttr & foregroundIntensity
+ bgIntensity := winAttr & backgroundIntensity
+ underline := winAttr & underscore
+ otherAttributes := winAttr &^ (foregroundMask | backgroundMask | underscore)
+ return &textAttributes{fgColor, bgColor, fgIntensity, bgIntensity, underline, otherAttributes}
+}
+
+func convertWinAttr(textAttr *textAttributes) uint16 {
+ var winAttr uint16
+ winAttr |= textAttr.foregroundColor
+ winAttr |= textAttr.backgroundColor
+ winAttr |= textAttr.foregroundIntensity
+ winAttr |= textAttr.backgroundIntensity
+ winAttr |= textAttr.underscore
+ winAttr |= textAttr.otherAttributes
+ return winAttr
+}
+
+func changeColor(param []byte) parseResult {
+ screenInfo := getConsoleScreenBufferInfo(uintptr(syscall.Stdout))
+ if screenInfo == nil {
+ return noConsole
+ }
+
+ winAttr := convertTextAttr(screenInfo.WAttributes)
+ strParam := string(param)
+ if len(strParam) <= 0 {
+ strParam = "0"
+ }
+ csiParam := strings.Split(strParam, string(separatorChar))
+ for _, p := range csiParam {
+ c, ok := colorMap[p]
+ switch {
+ case !ok:
+ switch p {
+ case ansiReset:
+ winAttr.foregroundColor = defaultAttr.foregroundColor
+ winAttr.backgroundColor = defaultAttr.backgroundColor
+ winAttr.foregroundIntensity = defaultAttr.foregroundIntensity
+ winAttr.backgroundIntensity = defaultAttr.backgroundIntensity
+ winAttr.underscore = 0
+ winAttr.otherAttributes = 0
+ case ansiIntensityOn:
+ winAttr.foregroundIntensity = foregroundIntensity
+ case ansiIntensityOff:
+ winAttr.foregroundIntensity = 0
+ case ansiUnderlineOn:
+ winAttr.underscore = underscore
+ case ansiUnderlineOff:
+ winAttr.underscore = 0
+ case ansiBlinkOn:
+ winAttr.backgroundIntensity = backgroundIntensity
+ case ansiBlinkOff:
+ winAttr.backgroundIntensity = 0
+ default:
+ // unknown code
+ }
+ case c.drawType == foreground:
+ winAttr.foregroundColor = c.code
+ case c.drawType == background:
+ winAttr.backgroundColor = c.code
+ }
+ }
+ winTextAttribute := convertWinAttr(winAttr)
+ setConsoleTextAttribute(uintptr(syscall.Stdout), winTextAttribute)
+
+ return changedColor
+}
+
+func parseEscapeSequence(command byte, param []byte) parseResult {
+ if defaultAttr == nil {
+ return noConsole
+ }
+
+ switch command {
+ case sgrCode:
+ return changeColor(param)
+ default:
+ return unknown
+ }
+}
+
+func (cw *ansiColorWriter) flushBuffer() (int, error) {
+ return cw.flushTo(cw.w)
+}
+
+func (cw *ansiColorWriter) resetBuffer() (int, error) {
+ return cw.flushTo(nil)
+}
+
+func (cw *ansiColorWriter) flushTo(w io.Writer) (int, error) {
+ var n1, n2 int
+ var err error
+
+ startBytes := cw.paramStartBuf.Bytes()
+ cw.paramStartBuf.Reset()
+ if w != nil {
+ n1, err = cw.w.Write(startBytes)
+ if err != nil {
+ return n1, err
+ }
+ } else {
+ n1 = len(startBytes)
+ }
+ paramBytes := cw.paramBuf.Bytes()
+ cw.paramBuf.Reset()
+ if w != nil {
+ n2, err = cw.w.Write(paramBytes)
+ if err != nil {
+ return n1 + n2, err
+ }
+ } else {
+ n2 = len(paramBytes)
+ }
+ return n1 + n2, nil
+}
+
+func isParameterChar(b byte) bool {
+ return ('0' <= b && b <= '9') || b == separatorChar
+}
+
+func (cw *ansiColorWriter) Write(p []byte) (int, error) {
+ r, nw, first, last := 0, 0, 0, 0
+ if cw.mode != DiscardNonColorEscSeq {
+ cw.state = outsideCsiCode
+ cw.resetBuffer()
+ }
+
+ var err error
+ for i, ch := range p {
+ switch cw.state {
+ case outsideCsiCode:
+ if ch == firstCsiChar {
+ cw.paramStartBuf.WriteByte(ch)
+ cw.state = firstCsiCode
+ }
+ case firstCsiCode:
+ switch ch {
+ case firstCsiChar:
+ cw.paramStartBuf.WriteByte(ch)
+ break
+ case secondeCsiChar:
+ cw.paramStartBuf.WriteByte(ch)
+ cw.state = secondCsiCode
+ last = i - 1
+ default:
+ cw.resetBuffer()
+ cw.state = outsideCsiCode
+ }
+ case secondCsiCode:
+ if isParameterChar(ch) {
+ cw.paramBuf.WriteByte(ch)
+ } else {
+ nw, err = cw.w.Write(p[first:last])
+ r += nw
+ if err != nil {
+ return r, err
+ }
+ first = i + 1
+ result := parseEscapeSequence(ch, cw.paramBuf.Bytes())
+ if result == noConsole || (cw.mode == OutputNonColorEscSeq && result == unknown) {
+ cw.paramBuf.WriteByte(ch)
+ nw, err := cw.flushBuffer()
+ if err != nil {
+ return r, err
+ }
+ r += nw
+ } else {
+ n, _ := cw.resetBuffer()
+ // Add one more to the size of the buffer for the last ch
+ r += n + 1
+ }
+
+ cw.state = outsideCsiCode
+ }
+ default:
+ cw.state = outsideCsiCode
+ }
+ }
+
+ if cw.mode != DiscardNonColorEscSeq || cw.state == outsideCsiCode {
+ nw, err = cw.w.Write(p[first:])
+ r += nw
+ }
+
+ return r, err
+}
diff --git a/logs/color_windows_test.go b/logs/color_windows_test.go
new file mode 100644
index 00000000..5074841a
--- /dev/null
+++ b/logs/color_windows_test.go
@@ -0,0 +1,294 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build windows
+
+package logs
+
+import (
+ "bytes"
+ "fmt"
+ "syscall"
+ "testing"
+)
+
+var GetConsoleScreenBufferInfo = getConsoleScreenBufferInfo
+
+func ChangeColor(color uint16) {
+ setConsoleTextAttribute(uintptr(syscall.Stdout), color)
+}
+
+func ResetColor() {
+ ChangeColor(uint16(0x0007))
+}
+
+func TestWritePlanText(t *testing.T) {
+ inner := bytes.NewBufferString("")
+ w := NewAnsiColorWriter(inner)
+ expected := "plain text"
+ fmt.Fprintf(w, expected)
+ actual := inner.String()
+ if actual != expected {
+ t.Errorf("Get %q, want %q", actual, expected)
+ }
+}
+
+func TestWriteParseText(t *testing.T) {
+ inner := bytes.NewBufferString("")
+ w := NewAnsiColorWriter(inner)
+
+ inputTail := "\x1b[0mtail text"
+ expectedTail := "tail text"
+ fmt.Fprintf(w, inputTail)
+ actualTail := inner.String()
+ inner.Reset()
+ if actualTail != expectedTail {
+ t.Errorf("Get %q, want %q", actualTail, expectedTail)
+ }
+
+ inputHead := "head text\x1b[0m"
+ expectedHead := "head text"
+ fmt.Fprintf(w, inputHead)
+ actualHead := inner.String()
+ inner.Reset()
+ if actualHead != expectedHead {
+ t.Errorf("Get %q, want %q", actualHead, expectedHead)
+ }
+
+ inputBothEnds := "both ends \x1b[0m text"
+ expectedBothEnds := "both ends text"
+ fmt.Fprintf(w, inputBothEnds)
+ actualBothEnds := inner.String()
+ inner.Reset()
+ if actualBothEnds != expectedBothEnds {
+ t.Errorf("Get %q, want %q", actualBothEnds, expectedBothEnds)
+ }
+
+ inputManyEsc := "\x1b\x1b\x1b\x1b[0m many esc"
+ expectedManyEsc := "\x1b\x1b\x1b many esc"
+ fmt.Fprintf(w, inputManyEsc)
+ actualManyEsc := inner.String()
+ inner.Reset()
+ if actualManyEsc != expectedManyEsc {
+ t.Errorf("Get %q, want %q", actualManyEsc, expectedManyEsc)
+ }
+
+ expectedSplit := "split text"
+ for _, ch := range "split \x1b[0m text" {
+ fmt.Fprintf(w, string(ch))
+ }
+ actualSplit := inner.String()
+ inner.Reset()
+ if actualSplit != expectedSplit {
+ t.Errorf("Get %q, want %q", actualSplit, expectedSplit)
+ }
+}
+
+type screenNotFoundError struct {
+ error
+}
+
+func writeAnsiColor(expectedText, colorCode string) (actualText string, actualAttributes uint16, err error) {
+ inner := bytes.NewBufferString("")
+ w := NewAnsiColorWriter(inner)
+ fmt.Fprintf(w, "\x1b[%sm%s", colorCode, expectedText)
+
+ actualText = inner.String()
+ screenInfo := GetConsoleScreenBufferInfo(uintptr(syscall.Stdout))
+ if screenInfo != nil {
+ actualAttributes = screenInfo.WAttributes
+ } else {
+ err = &screenNotFoundError{}
+ }
+ return
+}
+
+type testParam struct {
+ text string
+ attributes uint16
+ ansiColor string
+}
+
+func TestWriteAnsiColorText(t *testing.T) {
+ screenInfo := GetConsoleScreenBufferInfo(uintptr(syscall.Stdout))
+ if screenInfo == nil {
+ t.Fatal("Could not get ConsoleScreenBufferInfo")
+ }
+ defer ChangeColor(screenInfo.WAttributes)
+ defaultFgColor := screenInfo.WAttributes & uint16(0x0007)
+ defaultBgColor := screenInfo.WAttributes & uint16(0x0070)
+ defaultFgIntensity := screenInfo.WAttributes & uint16(0x0008)
+ defaultBgIntensity := screenInfo.WAttributes & uint16(0x0080)
+
+ fgParam := []testParam{
+ {"foreground black ", uint16(0x0000 | 0x0000), "30"},
+ {"foreground red ", uint16(0x0004 | 0x0000), "31"},
+ {"foreground green ", uint16(0x0002 | 0x0000), "32"},
+ {"foreground yellow ", uint16(0x0006 | 0x0000), "33"},
+ {"foreground blue ", uint16(0x0001 | 0x0000), "34"},
+ {"foreground magenta", uint16(0x0005 | 0x0000), "35"},
+ {"foreground cyan ", uint16(0x0003 | 0x0000), "36"},
+ {"foreground white ", uint16(0x0007 | 0x0000), "37"},
+ {"foreground default", defaultFgColor | 0x0000, "39"},
+ {"foreground light gray ", uint16(0x0000 | 0x0008 | 0x0000), "90"},
+ {"foreground light red ", uint16(0x0004 | 0x0008 | 0x0000), "91"},
+ {"foreground light green ", uint16(0x0002 | 0x0008 | 0x0000), "92"},
+ {"foreground light yellow ", uint16(0x0006 | 0x0008 | 0x0000), "93"},
+ {"foreground light blue ", uint16(0x0001 | 0x0008 | 0x0000), "94"},
+ {"foreground light magenta", uint16(0x0005 | 0x0008 | 0x0000), "95"},
+ {"foreground light cyan ", uint16(0x0003 | 0x0008 | 0x0000), "96"},
+ {"foreground light white ", uint16(0x0007 | 0x0008 | 0x0000), "97"},
+ }
+
+ bgParam := []testParam{
+ {"background black ", uint16(0x0007 | 0x0000), "40"},
+ {"background red ", uint16(0x0007 | 0x0040), "41"},
+ {"background green ", uint16(0x0007 | 0x0020), "42"},
+ {"background yellow ", uint16(0x0007 | 0x0060), "43"},
+ {"background blue ", uint16(0x0007 | 0x0010), "44"},
+ {"background magenta", uint16(0x0007 | 0x0050), "45"},
+ {"background cyan ", uint16(0x0007 | 0x0030), "46"},
+ {"background white ", uint16(0x0007 | 0x0070), "47"},
+ {"background default", uint16(0x0007) | defaultBgColor, "49"},
+ {"background light gray ", uint16(0x0007 | 0x0000 | 0x0080), "100"},
+ {"background light red ", uint16(0x0007 | 0x0040 | 0x0080), "101"},
+ {"background light green ", uint16(0x0007 | 0x0020 | 0x0080), "102"},
+ {"background light yellow ", uint16(0x0007 | 0x0060 | 0x0080), "103"},
+ {"background light blue ", uint16(0x0007 | 0x0010 | 0x0080), "104"},
+ {"background light magenta", uint16(0x0007 | 0x0050 | 0x0080), "105"},
+ {"background light cyan ", uint16(0x0007 | 0x0030 | 0x0080), "106"},
+ {"background light white ", uint16(0x0007 | 0x0070 | 0x0080), "107"},
+ }
+
+ resetParam := []testParam{
+ {"all reset", defaultFgColor | defaultBgColor | defaultFgIntensity | defaultBgIntensity, "0"},
+ {"all reset", defaultFgColor | defaultBgColor | defaultFgIntensity | defaultBgIntensity, ""},
+ }
+
+ boldParam := []testParam{
+ {"bold on", uint16(0x0007 | 0x0008), "1"},
+ {"bold off", uint16(0x0007), "21"},
+ }
+
+ underscoreParam := []testParam{
+ {"underscore on", uint16(0x0007 | 0x8000), "4"},
+ {"underscore off", uint16(0x0007), "24"},
+ }
+
+ blinkParam := []testParam{
+ {"blink on", uint16(0x0007 | 0x0080), "5"},
+ {"blink off", uint16(0x0007), "25"},
+ }
+
+ mixedParam := []testParam{
+ {"both black, bold, underline, blink", uint16(0x0000 | 0x0000 | 0x0008 | 0x8000 | 0x0080), "30;40;1;4;5"},
+ {"both red, bold, underline, blink", uint16(0x0004 | 0x0040 | 0x0008 | 0x8000 | 0x0080), "31;41;1;4;5"},
+ {"both green, bold, underline, blink", uint16(0x0002 | 0x0020 | 0x0008 | 0x8000 | 0x0080), "32;42;1;4;5"},
+ {"both yellow, bold, underline, blink", uint16(0x0006 | 0x0060 | 0x0008 | 0x8000 | 0x0080), "33;43;1;4;5"},
+ {"both blue, bold, underline, blink", uint16(0x0001 | 0x0010 | 0x0008 | 0x8000 | 0x0080), "34;44;1;4;5"},
+ {"both magenta, bold, underline, blink", uint16(0x0005 | 0x0050 | 0x0008 | 0x8000 | 0x0080), "35;45;1;4;5"},
+ {"both cyan, bold, underline, blink", uint16(0x0003 | 0x0030 | 0x0008 | 0x8000 | 0x0080), "36;46;1;4;5"},
+ {"both white, bold, underline, blink", uint16(0x0007 | 0x0070 | 0x0008 | 0x8000 | 0x0080), "37;47;1;4;5"},
+ {"both default, bold, underline, blink", uint16(defaultFgColor | defaultBgColor | 0x0008 | 0x8000 | 0x0080), "39;49;1;4;5"},
+ }
+
+ assertTextAttribute := func(expectedText string, expectedAttributes uint16, ansiColor string) {
+ actualText, actualAttributes, err := writeAnsiColor(expectedText, ansiColor)
+ if actualText != expectedText {
+ t.Errorf("Get %q, want %q", actualText, expectedText)
+ }
+ if err != nil {
+ t.Fatal("Could not get ConsoleScreenBufferInfo")
+ }
+ if actualAttributes != expectedAttributes {
+ t.Errorf("Text: %q, Get 0x%04x, want 0x%04x", expectedText, actualAttributes, expectedAttributes)
+ }
+ }
+
+ for _, v := range fgParam {
+ ResetColor()
+ assertTextAttribute(v.text, v.attributes, v.ansiColor)
+ }
+
+ for _, v := range bgParam {
+ ChangeColor(uint16(0x0070 | 0x0007))
+ assertTextAttribute(v.text, v.attributes, v.ansiColor)
+ }
+
+ for _, v := range resetParam {
+ ChangeColor(uint16(0x0000 | 0x0070 | 0x0008))
+ assertTextAttribute(v.text, v.attributes, v.ansiColor)
+ }
+
+ ResetColor()
+ for _, v := range boldParam {
+ assertTextAttribute(v.text, v.attributes, v.ansiColor)
+ }
+
+ ResetColor()
+ for _, v := range underscoreParam {
+ assertTextAttribute(v.text, v.attributes, v.ansiColor)
+ }
+
+ ResetColor()
+ for _, v := range blinkParam {
+ assertTextAttribute(v.text, v.attributes, v.ansiColor)
+ }
+
+ for _, v := range mixedParam {
+ ResetColor()
+ assertTextAttribute(v.text, v.attributes, v.ansiColor)
+ }
+}
+
+func TestIgnoreUnknownSequences(t *testing.T) {
+ inner := bytes.NewBufferString("")
+ w := NewModeAnsiColorWriter(inner, OutputNonColorEscSeq)
+
+ inputText := "\x1b[=decpath mode"
+ expectedTail := inputText
+ fmt.Fprintf(w, inputText)
+ actualTail := inner.String()
+ inner.Reset()
+ if actualTail != expectedTail {
+ t.Errorf("Get %q, want %q", actualTail, expectedTail)
+ }
+
+ inputText = "\x1b[=tailing esc and bracket\x1b["
+ expectedTail = inputText
+ fmt.Fprintf(w, inputText)
+ actualTail = inner.String()
+ inner.Reset()
+ if actualTail != expectedTail {
+ t.Errorf("Get %q, want %q", actualTail, expectedTail)
+ }
+
+ inputText = "\x1b[?tailing esc\x1b"
+ expectedTail = inputText
+ fmt.Fprintf(w, inputText)
+ actualTail = inner.String()
+ inner.Reset()
+ if actualTail != expectedTail {
+ t.Errorf("Get %q, want %q", actualTail, expectedTail)
+ }
+
+ inputText = "\x1b[1h;3punended color code invalid\x1b3"
+ expectedTail = inputText
+ fmt.Fprintf(w, inputText)
+ actualTail = inner.String()
+ inner.Reset()
+ if actualTail != expectedTail {
+ t.Errorf("Get %q, want %q", actualTail, expectedTail)
+ }
+}
diff --git a/logs/file.go b/logs/file.go
index 7798a221..42146dae 100644
--- a/logs/file.go
+++ b/logs/file.go
@@ -159,6 +159,10 @@ func (w *fileLogWriter) createLogFile() (*os.File, error) {
return nil, err
}
fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm))
+ if err == nil {
+ // Make sure file perm is user set perm cause of `os.OpenFile` will obey umask
+ os.Chmod(w.Filename, os.FileMode(perm))
+ }
return fd, err
}
diff --git a/logs/file_test.go b/logs/file_test.go
index 23370947..69a66d84 100644
--- a/logs/file_test.go
+++ b/logs/file_test.go
@@ -26,7 +26,8 @@ import (
func TestFilePerm(t *testing.T) {
log := NewLogger(10000)
- log.SetLogger("file", `{"filename":"test.log", "perm": "0600"}`)
+ // use 0666 as test perm cause the default umask is 022
+ log.SetLogger("file", `{"filename":"test.log", "perm": "0666"}`)
log.Debug("debug")
log.Informational("info")
log.Notice("notice")
@@ -39,7 +40,7 @@ func TestFilePerm(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- if file.Mode() != 0600 {
+ if file.Mode() != 0666 {
t.Fatal("unexpected log file permission")
}
os.Remove("test.log")
diff --git a/logs/jianliao.go b/logs/jianliao.go
new file mode 100644
index 00000000..3755118d
--- /dev/null
+++ b/logs/jianliao.go
@@ -0,0 +1,78 @@
+package logs
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "time"
+)
+
+// JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook
+type JLWriter struct {
+ AuthorName string `json:"authorname"`
+ Title string `json:"title"`
+ WebhookURL string `json:"webhookurl"`
+ RedirectURL string `json:"redirecturl,omitempty"`
+ ImageURL string `json:"imageurl,omitempty"`
+ Level int `json:"level"`
+}
+
+// newJLWriter create jiaoliao writer.
+func newJLWriter() Logger {
+ return &JLWriter{Level: LevelTrace}
+}
+
+// Init JLWriter with json config string
+func (s *JLWriter) Init(jsonconfig string) error {
+ err := json.Unmarshal([]byte(jsonconfig), s)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// WriteMsg write message in smtp writer.
+// it will send an email with subject and only this message.
+func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error {
+ if level > s.Level {
+ return nil
+ }
+
+ text := fmt.Sprintf("%s %s", when.Format("2006-01-02 15:04:05"), msg)
+
+ form := url.Values{}
+ form.Add("authorName", s.AuthorName)
+ form.Add("title", s.Title)
+ form.Add("text", text)
+ if s.RedirectURL != "" {
+ form.Add("redirectUrl", s.RedirectURL)
+ }
+ if s.ImageURL != "" {
+ form.Add("imageUrl", s.ImageURL)
+ }
+
+ resp, err := http.PostForm(s.WebhookURL, form)
+ if err != nil {
+ return err
+ }
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode)
+ }
+ resp.Body.Close()
+ return nil
+}
+
+// Flush implementing method. empty.
+func (s *JLWriter) Flush() {
+ return
+}
+
+// Destroy implementing method. empty.
+func (s *JLWriter) Destroy() {
+ return
+}
+
+func init() {
+ Register(AdapterJianLiao, newJLWriter)
+}
diff --git a/logs/log.go b/logs/log.go
index c43782f3..3d512d2e 100644
--- a/logs/log.go
+++ b/logs/log.go
@@ -66,9 +66,11 @@ const (
AdapterConsole = "console"
AdapterFile = "file"
AdapterMultiFile = "multifile"
- AdapterMail = "stmp"
+ AdapterMail = "smtp"
AdapterConn = "conn"
AdapterEs = "es"
+ AdapterJianLiao = "jianliao"
+ AdapterSlack = "slack"
)
// Legacy log level constants to ensure backwards compatibility.
@@ -260,12 +262,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error
bl.setLogger(AdapterConsole)
bl.lock.Unlock()
}
- if logLevel == levelLoggerImpl {
- // set to emergency to ensure all log will be print out correctly
- logLevel = LevelEmergency
- } else {
- msg = levelPrefix[logLevel] + msg
- }
+
if len(v) > 0 {
msg = fmt.Sprintf(msg, v...)
}
@@ -279,6 +276,15 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error
_, filename := path.Split(file)
msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg
}
+
+ //set level info in front of filename info
+ if logLevel == levelLoggerImpl {
+ // set to emergency to ensure all log will be print out correctly
+ logLevel = LevelEmergency
+ } else {
+ msg = levelPrefix[logLevel] + msg
+ }
+
if bl.asynchronous {
lm := logMsgPool.Get().(*logMsg)
lm.level = logLevel
@@ -532,10 +538,10 @@ func EnableFuncCallDepth(b bool) {
beeLogger.enableFuncCallDepth = b
}
-// SetLogFuncCall set the CallDepth, default is 3
+// SetLogFuncCall set the CallDepth, default is 4
func SetLogFuncCall(b bool) {
beeLogger.EnableFuncCallDepth(b)
- beeLogger.SetLogFuncCallDepth(3)
+ beeLogger.SetLogFuncCallDepth(4)
}
// SetLogFuncCallDepth set log funcCallDepth
diff --git a/logs/logger.go b/logs/logger.go
index 2f47e569..e0abfdc4 100644
--- a/logs/logger.go
+++ b/logs/logger.go
@@ -15,7 +15,9 @@
package logs
import (
+ "fmt"
"io"
+ "os"
"sync"
"time"
)
@@ -36,18 +38,56 @@ func (lg *logWriter) println(when time.Time, msg string) {
lg.Unlock()
}
-const y1 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999`
-const y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
-const mo1 = `000000000111`
-const mo2 = `123456789012`
-const d1 = `0000000001111111111222222222233`
-const d2 = `1234567890123456789012345678901`
-const h1 = `000000000011111111112222`
-const h2 = `012345678901234567890123`
-const mi1 = `000000000011111111112222222222333333333344444444445555555555`
-const mi2 = `012345678901234567890123456789012345678901234567890123456789`
-const s1 = `000000000011111111112222222222333333333344444444445555555555`
-const s2 = `012345678901234567890123456789012345678901234567890123456789`
+type outputMode int
+
+// DiscardNonColorEscSeq supports the divided color escape sequence.
+// But non-color escape sequence is not output.
+// Please use the OutputNonColorEscSeq If you want to output a non-color
+// escape sequences such as ncurses. However, it does not support the divided
+// color escape sequence.
+const (
+ _ outputMode = iota
+ DiscardNonColorEscSeq
+ OutputNonColorEscSeq
+)
+
+// NewAnsiColorWriter creates and initializes a new ansiColorWriter
+// using io.Writer w as its initial contents.
+// In the console of Windows, which change the foreground and background
+// colors of the text by the escape sequence.
+// In the console of other systems, which writes to w all text.
+func NewAnsiColorWriter(w io.Writer) io.Writer {
+ return NewModeAnsiColorWriter(w, DiscardNonColorEscSeq)
+}
+
+// NewModeAnsiColorWriter create and initializes a new ansiColorWriter
+// by specifying the outputMode.
+func NewModeAnsiColorWriter(w io.Writer, mode outputMode) io.Writer {
+ if _, ok := w.(*ansiColorWriter); !ok {
+ return &ansiColorWriter{
+ w: w,
+ mode: mode,
+ }
+ }
+ return w
+}
+
+const (
+ y1 = `0123456789`
+ y2 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
+ y3 = `0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999`
+ y4 = `0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789`
+ mo1 = `000000000111`
+ mo2 = `123456789012`
+ d1 = `0000000001111111111222222222233`
+ d2 = `1234567890123456789012345678901`
+ h1 = `000000000011111111112222`
+ h2 = `012345678901234567890123`
+ mi1 = `000000000011111111112222222222333333333344444444445555555555`
+ mi2 = `012345678901234567890123456789012345678901234567890123456789`
+ s1 = `000000000011111111112222222222333333333344444444445555555555`
+ s2 = `012345678901234567890123456789012345678901234567890123456789`
+)
func formatTimeHeader(when time.Time) ([]byte, int) {
y, mo, d := when.Date()
@@ -55,12 +95,10 @@ func formatTimeHeader(when time.Time) ([]byte, int) {
//len("2006/01/02 15:04:05 ")==20
var buf [20]byte
- //change to '3' after 984 years, LOL
- buf[0] = '2'
- //change to '1' after 84 years, LOL
- buf[1] = '0'
- buf[2] = y1[y-2000]
- buf[3] = y2[y-2000]
+ buf[0] = y1[y/1000%10]
+ buf[1] = y2[y/100]
+ buf[2] = y3[y-y/100*100]
+ buf[3] = y4[y-y/100*100]
buf[4] = '/'
buf[5] = mo1[mo-1]
buf[6] = mo2[mo-1]
@@ -80,3 +118,71 @@ func formatTimeHeader(when time.Time) ([]byte, int) {
return buf[0:], d
}
+
+var (
+ green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109})
+ white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109})
+ yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109})
+ red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109})
+ blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109})
+ magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109})
+ cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109})
+
+ w32Green = string([]byte{27, 91, 52, 50, 109})
+ w32White = string([]byte{27, 91, 52, 55, 109})
+ w32Yellow = string([]byte{27, 91, 52, 51, 109})
+ w32Red = string([]byte{27, 91, 52, 49, 109})
+ w32Blue = string([]byte{27, 91, 52, 52, 109})
+ w32Magenta = string([]byte{27, 91, 52, 53, 109})
+ w32Cyan = string([]byte{27, 91, 52, 54, 109})
+
+ reset = string([]byte{27, 91, 48, 109})
+)
+
+func ColorByStatus(cond bool, code int) string {
+ switch {
+ case code >= 200 && code < 300:
+ return map[bool]string{true: green, false: w32Green}[cond]
+ case code >= 300 && code < 400:
+ return map[bool]string{true: white, false: w32White}[cond]
+ case code >= 400 && code < 500:
+ return map[bool]string{true: yellow, false: w32Yellow}[cond]
+ default:
+ return map[bool]string{true: red, false: w32Red}[cond]
+ }
+}
+
+func ColorByMethod(cond bool, method string) string {
+ switch method {
+ case "GET":
+ return map[bool]string{true: blue, false: w32Blue}[cond]
+ case "POST":
+ return map[bool]string{true: cyan, false: w32Cyan}[cond]
+ case "PUT":
+ return map[bool]string{true: yellow, false: w32Yellow}[cond]
+ case "DELETE":
+ return map[bool]string{true: red, false: w32Red}[cond]
+ case "PATCH":
+ return map[bool]string{true: green, false: w32Green}[cond]
+ case "HEAD":
+ return map[bool]string{true: magenta, false: w32Magenta}[cond]
+ case "OPTIONS":
+ return map[bool]string{true: white, false: w32White}[cond]
+ default:
+ return reset
+ }
+}
+
+// Guard Mutex to guarantee atomicity of W32Debug(string) function
+var mu sync.Mutex
+
+// Helper method to output colored logs in Windows terminals
+func W32Debug(msg string) {
+ mu.Lock()
+ defer mu.Unlock()
+
+ current := time.Now()
+ w := NewAnsiColorWriter(os.Stdout)
+
+ fmt.Fprintf(w, "[beego] %v %s\n", current.Format("2006/01/02 - 15:04:05"), msg)
+}
diff --git a/logs/logger_test.go b/logs/logger_test.go
index 4627853a..119b7bd3 100644
--- a/logs/logger_test.go
+++ b/logs/logger_test.go
@@ -15,6 +15,7 @@
package logs
import (
+ "bytes"
"testing"
"time"
)
@@ -55,3 +56,20 @@ func TestFormatHeader_1(t *testing.T) {
tm = tm.Add(dur)
}
}
+
+func TestNewAnsiColor1(t *testing.T) {
+ inner := bytes.NewBufferString("")
+ w := NewAnsiColorWriter(inner)
+ if w == inner {
+ t.Errorf("Get %#v, want %#v", w, inner)
+ }
+}
+
+func TestNewAnsiColor2(t *testing.T) {
+ inner := bytes.NewBufferString("")
+ w1 := NewAnsiColorWriter(inner)
+ w2 := NewAnsiColorWriter(w1)
+ if w1 != w2 {
+ t.Errorf("Get %#v, want %#v", w1, w2)
+ }
+}
diff --git a/logs/slack.go b/logs/slack.go
new file mode 100644
index 00000000..eddedd5d
--- /dev/null
+++ b/logs/slack.go
@@ -0,0 +1,66 @@
+package logs
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "time"
+)
+
+// SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook
+type SLACKWriter struct {
+ WebhookURL string `json:"webhookurl"`
+ Level int `json:"level"`
+}
+
+// newSLACKWriter create jiaoliao writer.
+func newSLACKWriter() Logger {
+ return &SLACKWriter{Level: LevelTrace}
+}
+
+// Init SLACKWriter with json config string
+func (s *SLACKWriter) Init(jsonconfig string) error {
+ err := json.Unmarshal([]byte(jsonconfig), s)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// WriteMsg write message in smtp writer.
+// it will send an email with subject and only this message.
+func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error {
+ if level > s.Level {
+ return nil
+ }
+
+ text := fmt.Sprintf("{\"text\": \"%s %s\"}", when.Format("2006-01-02 15:04:05"), msg)
+
+ form := url.Values{}
+ form.Add("payload", text)
+
+ resp, err := http.PostForm(s.WebhookURL, form)
+ if err != nil {
+ return err
+ }
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode)
+ }
+ resp.Body.Close()
+ return nil
+}
+
+// Flush implementing method. empty.
+func (s *SLACKWriter) Flush() {
+ return
+}
+
+// Destroy implementing method. empty.
+func (s *SLACKWriter) Destroy() {
+ return
+}
+
+func init() {
+ Register(AdapterSlack, newSLACKWriter)
+}
diff --git a/mime.go b/mime.go
index e85fcb2a..ca2878ab 100644
--- a/mime.go
+++ b/mime.go
@@ -339,7 +339,7 @@ var mimemaps = map[string]string{
".pvu": "paleovu/x-pv",
".pwz": "application/vndms-powerpoint",
".py": "text/x-scriptphyton",
- ".pyc": "applicaiton/x-bytecodepython",
+ ".pyc": "application/x-bytecodepython",
".qcp": "audio/vndqcelp",
".qd3": "x-world/x-3dmf",
".qd3d": "x-world/x-3dmf",
diff --git a/orm/db.go b/orm/db.go
index 78c72e87..30d8ae4e 100644
--- a/orm/db.go
+++ b/orm/db.go
@@ -243,6 +243,9 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
f.SetRaw(tnow.In(DefaultTimeLoc))
+ } else if field.Kind() == reflect.Ptr {
+ v := tnow.In(DefaultTimeLoc)
+ field.Set(reflect.ValueOf(&v))
} else {
field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc)))
}
@@ -307,7 +310,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
}
// query sql ,read records and persist in dbBaser.
-func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
+func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
var whereCols []string
var args []interface{}
@@ -338,7 +341,12 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
sep = fmt.Sprintf("%s = ? AND %s", Q, Q)
wheres := strings.Join(whereCols, sep)
- query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q)
+ forUpdate := ""
+ if isForUpdate {
+ forUpdate = "FOR UPDATE"
+ }
+
+ query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate)
refs := make([]interface{}, colsNum)
for i := range refs {
@@ -485,6 +493,110 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s
return id, err
}
+// InsertOrUpdate a row
+// If your primary key or unique column conflict will update
+// If no will insert
+func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
+ args0 := ""
+ iouStr := ""
+ argsMap := map[string]string{}
+ switch a.Driver {
+ case DRMySQL:
+ iouStr = "ON DUPLICATE KEY UPDATE"
+ case DRPostgres:
+ if len(args) == 0 {
+ return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
+ } else {
+ args0 = strings.ToLower(args[0])
+ iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
+ }
+ default:
+ return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
+ }
+
+ //Get on the key-value pairs
+ for _, v := range args {
+ kv := strings.Split(v, "=")
+ if len(kv) == 2 {
+ argsMap[strings.ToLower(kv[0])] = kv[1]
+ }
+ }
+
+ isMulti := false
+ names := make([]string, 0, len(mi.fields.dbcols)-1)
+ Q := d.ins.TableQuote()
+ values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ)
+
+ if err != nil {
+ return 0, err
+ }
+
+ marks := make([]string, len(names))
+ updateValues := make([]interface{}, 0)
+ updates := make([]string, len(names))
+ var conflitValue interface{}
+ for i, v := range names {
+ marks[i] = "?"
+ valueStr := argsMap[strings.ToLower(v)]
+ if v == args0 {
+ conflitValue = values[i]
+ }
+ if valueStr != "" {
+ switch a.Driver {
+ case DRMySQL:
+ updates[i] = v + "=" + valueStr
+ case DRPostgres:
+ if conflitValue != nil {
+ //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values
+ updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0)
+ updateValues = append(updateValues, conflitValue)
+ } else {
+ return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v)
+ }
+ }
+ } else {
+ updates[i] = v + "=?"
+ updateValues = append(updateValues, values[i])
+ }
+ }
+
+ values = append(values, updateValues...)
+
+ sep := fmt.Sprintf("%s, %s", Q, Q)
+ qmarks := strings.Join(marks, ", ")
+ qupdates := strings.Join(updates, ", ")
+ columns := strings.Join(names, sep)
+
+ multi := len(values) / len(names)
+
+ if isMulti {
+ qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
+ }
+ //conflitValue maybe is a int,can`t use fmt.Sprintf
+ query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr)
+
+ d.ins.ReplaceMarks(&query)
+
+ if isMulti || !d.ins.HasReturningID(mi, &query) {
+ res, err := q.Exec(query, values...)
+ if err == nil {
+ if isMulti {
+ return res.RowsAffected()
+ }
+ return res.LastInsertId()
+ }
+ return 0, err
+ }
+
+ row := q.QueryRow(query, values...)
+ var id int64
+ err = row.Scan(&id)
+ if err.Error() == `pq: syntax error at or near "ON"` {
+ err = fmt.Errorf("postgres version must 9.5 or higher")
+ }
+ return id, err
+}
+
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
@@ -527,18 +639,36 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
// execute delete sql dbQuerier with given struct reflect.Value.
// delete index is pk.
-func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
- pkName, pkValue, ok := getExistPk(mi, ind)
- if ok == false {
- return 0, ErrMissPK
+func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
+ var whereCols []string
+ var args []interface{}
+ // if specify cols length > 0, then use it for where condition.
+ if len(cols) > 0 {
+ var err error
+ whereCols = make([]string, 0, len(cols))
+ args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
+ if err != nil {
+ return 0, err
+ }
+ } else {
+ // default use pk value as where condtion.
+ pkColumn, pkValue, ok := getExistPk(mi, ind)
+ if ok == false {
+ return 0, ErrMissPK
+ }
+ whereCols = []string{pkColumn}
+ args = append(args, pkValue)
}
Q := d.ins.TableQuote()
- query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q)
+ sep := fmt.Sprintf("%s = ? AND %s", Q, Q)
+ wheres := strings.Join(whereCols, sep)
+
+ query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q)
d.ins.ReplaceMarks(&query)
- res, err := q.Exec(query, pkValue)
+ res, err := q.Exec(query, args...)
if err == nil {
num, err := res.RowsAffected()
if err != nil {
@@ -552,7 +682,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
}
}
- err := d.deleteRels(q, mi, []interface{}{pkValue}, tz)
+ err := d.deleteRels(q, mi, args, tz)
if err != nil {
return num, err
}
@@ -957,12 +1087,17 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSQL(cond, false, tz)
+ groupBy := tables.getGroupSQL(qs.groups)
tables.getOrderSQL(qs.orders)
join := tables.getJoinSQL()
Q := d.ins.TableQuote()
- query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s", Q, mi.table, Q, join, where)
+ query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy)
+
+ if groupBy != "" {
+ query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query)
+ }
d.ins.ReplaceMarks(&query)
@@ -1273,8 +1408,14 @@ setValue:
if isNative {
if value == nil {
value = time.Time{}
+ } else if field.Kind() == reflect.Ptr {
+ if value != nil {
+ v := value.(time.Time)
+ field.Set(reflect.ValueOf(&v))
+ }
+ } else {
+ field.Set(reflect.ValueOf(value))
}
- field.Set(reflect.ValueOf(value))
}
case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr:
if value != nil {
diff --git a/orm/db_alias.go b/orm/db_alias.go
index b6c833a7..c95d49c9 100644
--- a/orm/db_alias.go
+++ b/orm/db_alias.go
@@ -80,7 +80,7 @@ type _dbCache struct {
func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock()
defer ac.mux.Unlock()
- if _, ok := ac.cache[name]; ok == false {
+ if _, ok := ac.cache[name]; !ok {
ac.cache[name] = al
added = true
}
diff --git a/orm/db_utils.go b/orm/db_utils.go
index cf465d02..0279a14a 100644
--- a/orm/db_utils.go
+++ b/orm/db_utils.go
@@ -145,7 +145,7 @@ outFor:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(formatDate)
- } else if fi.fieldType == TypeDateTimeField {
+ } else if fi != nil && fi.fieldType == TypeDateTimeField {
arg = v.In(tz).Format(formatDateTime)
} else {
arg = v.In(tz).Format(formatTime)
@@ -154,7 +154,7 @@ outFor:
typ := val.Type()
name := getFullName(typ)
var value interface{}
- if mmi, ok := modelCache.getByFN(name); ok {
+ if mmi, ok := modelCache.getByFullName(name); ok {
if _, vu, exist := getExistPk(mmi, val); exist {
value = vu
}
diff --git a/orm/models.go b/orm/models.go
index faf551be..1d5a4dc2 100644
--- a/orm/models.go
+++ b/orm/models.go
@@ -29,39 +29,18 @@ const (
var (
modelCache = &_modelCache{
- cache: make(map[string]*modelInfo),
- cacheByFN: make(map[string]*modelInfo),
- }
- supportTag = map[string]int{
- "-": 1,
- "null": 1,
- "index": 1,
- "unique": 1,
- "pk": 1,
- "auto": 1,
- "auto_now": 1,
- "auto_now_add": 1,
- "size": 2,
- "column": 2,
- "default": 2,
- "rel": 2,
- "reverse": 2,
- "rel_table": 2,
- "rel_through": 2,
- "digits": 2,
- "decimals": 2,
- "on_delete": 2,
- "type": 2,
+ cache: make(map[string]*modelInfo),
+ cacheByFullName: make(map[string]*modelInfo),
}
)
// model info collection
type _modelCache struct {
- sync.RWMutex
- orders []string
- cache map[string]*modelInfo
- cacheByFN map[string]*modelInfo
- done bool
+ sync.RWMutex // only used outsite for bootStrap
+ orders []string
+ cache map[string]*modelInfo
+ cacheByFullName map[string]*modelInfo
+ done bool
}
// get all model info
@@ -88,9 +67,9 @@ func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
return
}
-// get model info by field name
-func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
- mi, ok = mc.cacheByFN[name]
+// get model info by full name
+func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) {
+ mi, ok = mc.cacheByFullName[name]
return
}
@@ -98,7 +77,7 @@ func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table]
mc.cache[table] = mi
- mc.cacheByFN[mi.fullName] = mi
+ mc.cacheByFullName[mi.fullName] = mi
if mii == nil {
mc.orders = append(mc.orders, table)
}
@@ -109,7 +88,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
func (mc *_modelCache) clean() {
mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo)
- mc.cacheByFN = make(map[string]*modelInfo)
+ mc.cacheByFullName = make(map[string]*modelInfo)
mc.done = false
}
diff --git a/orm/models_boot.go b/orm/models_boot.go
index c9905330..4dbb54a9 100644
--- a/orm/models_boot.go
+++ b/orm/models_boot.go
@@ -15,7 +15,6 @@
package orm
import (
- "errors"
"fmt"
"os"
"reflect"
@@ -23,24 +22,34 @@ import (
)
// register models.
-// prefix means table name prefix.
-func registerModel(prefix string, model interface{}) {
+// PrefixOrSuffix means table name prefix or suffix.
+// isPrefix whether the prefix is prefix or suffix
+func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
val := reflect.ValueOf(model)
- ind := reflect.Indirect(val)
- typ := ind.Type()
+ typ := reflect.Indirect(val).Type()
if val.Kind() != reflect.Ptr {
panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ)))
}
+ // For this case:
+ // u := &User{}
+ // registerModel(&u)
+ if typ.Kind() == reflect.Ptr {
+ panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ))
+ }
table := getTableName(val)
- if prefix != "" {
- table = prefix + table
+ if PrefixOrSuffix != "" {
+ if isPrefix {
+ table = PrefixOrSuffix + table
+ } else {
+ table = table + PrefixOrSuffix
+ }
}
-
+ // models's fullname is pkgpath + struct name
name := getFullName(typ)
- if _, ok := modelCache.getByFN(name); ok {
+ if _, ok := modelCache.getByFullName(name); ok {
fmt.Printf(" model `%s` repeat register, must be unique\n", name)
os.Exit(2)
}
@@ -50,34 +59,34 @@ func registerModel(prefix string, model interface{}) {
os.Exit(2)
}
- info := newModelInfo(val)
- if info.fields.pk == nil {
+ mi := newModelInfo(val)
+ if mi.fields.pk == nil {
outFor:
- for _, fi := range info.fields.fieldsDB {
+ for _, fi := range mi.fields.fieldsDB {
if strings.ToLower(fi.name) == "id" {
switch fi.addrValue.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
fi.auto = true
fi.pk = true
- info.fields.pk = fi
+ mi.fields.pk = fi
break outFor
}
}
}
- if info.fields.pk == nil {
+ if mi.fields.pk == nil {
fmt.Printf(" `%s` need a primary key field, default use 'id' if not set\n", name)
os.Exit(2)
}
}
- info.table = table
- info.pkg = typ.PkgPath()
- info.model = model
- info.manual = true
+ mi.table = table
+ mi.pkg = typ.PkgPath()
+ mi.model = model
+ mi.manual = true
- modelCache.set(table, info)
+ modelCache.set(table, mi)
}
// boostrap models
@@ -85,12 +94,10 @@ func bootStrap() {
if modelCache.done {
return
}
-
var (
err error
models map[string]*modelInfo
)
-
if dataBaseCache.getDefault() == nil {
err = fmt.Errorf("must have one register DataBase alias named `default`")
goto end
@@ -101,14 +108,13 @@ func bootStrap() {
for _, fi := range mi.fields.columns {
if fi.rel || fi.reverse {
elm := fi.addrValue.Type().Elem()
- switch fi.fieldType {
- case RelReverseMany, RelManyToMany:
+ if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany {
elm = elm.Elem()
}
-
+ // check the rel or reverse model already register
name := getFullName(elm)
- mii, ok := modelCache.getByFN(name)
- if ok == false || mii.pkg != elm.PkgPath() {
+ mii, ok := modelCache.getByFullName(name)
+ if !ok || mii.pkg != elm.PkgPath() {
err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
goto end
}
@@ -117,20 +123,17 @@ func bootStrap() {
switch fi.fieldType {
case RelManyToMany:
if fi.relThrough != "" {
- msg := fmt.Sprintf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
pn := fi.relThrough[:i]
- rmi, ok := modelCache.getByFN(fi.relThrough)
+ rmi, ok := modelCache.getByFullName(fi.relThrough)
if ok == false || pn != rmi.pkg {
- err = errors.New(msg + " cannot find table")
+ err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
goto end
}
-
fi.relThroughModelInfo = rmi
fi.relTable = rmi.table
-
} else {
- err = errors.New(msg)
+ err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
goto end
}
} else {
@@ -138,7 +141,6 @@ func bootStrap() {
if fi.relTable != "" {
i.table = fi.relTable
}
-
if v := modelCache.set(i.table, i); v != nil {
err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable)
goto end
@@ -216,7 +218,6 @@ func bootStrap() {
}
}
}
-
if fi.reverseFieldInfoTwo == nil {
err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct",
fi.relThroughModelInfo.fullName)
@@ -300,17 +301,31 @@ end:
// RegisterModel register models
func RegisterModel(models ...interface{}) {
+ if modelCache.done {
+ panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
+ }
RegisterModelWithPrefix("", models...)
}
// RegisterModelWithPrefix register models with a prefix
func RegisterModelWithPrefix(prefix string, models ...interface{}) {
if modelCache.done {
- panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
+ panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap"))
}
for _, model := range models {
- registerModel(prefix, model)
+ registerModel(prefix, model, true)
+ }
+}
+
+// RegisterModelWithSuffix register models with a suffix
+func RegisterModelWithSuffix(suffix string, models ...interface{}) {
+ if modelCache.done {
+ panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap"))
+ }
+
+ for _, model := range models {
+ registerModel(suffix, model, false)
}
}
@@ -320,7 +335,6 @@ func BootStrap() {
if modelCache.done {
return
}
-
modelCache.Lock()
defer modelCache.Unlock()
bootStrap()
diff --git a/orm/models_info_f.go b/orm/models_info_f.go
index be6c9aa4..4b3d3e27 100644
--- a/orm/models_info_f.go
+++ b/orm/models_info_f.go
@@ -104,7 +104,7 @@ type fieldInfo struct {
mi *modelInfo
fieldIndex []int
fieldType int
- dbcol bool
+ dbcol bool // table column fk and onetoone
inModel bool
name string
fullName string
@@ -116,13 +116,13 @@ type fieldInfo struct {
null bool
index bool
unique bool
- colDefault bool
- initial StrTo
+ colDefault bool // whether has default tag
+ initial StrTo // store the default value
size int
toText bool
autoNow bool
autoNowAdd bool
- rel bool
+ rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true
reverse bool
reverseField string
reverseFieldInfo *fieldInfo
@@ -134,7 +134,7 @@ type fieldInfo struct {
relModelInfo *modelInfo
digits int
decimals int
- isFielder bool
+ isFielder bool // implement Fielder interface
onDelete string
}
@@ -143,7 +143,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
var (
tag string
tagValue string
- initial StrTo
+ initial StrTo // store the default value
fieldType int
attrs map[string]bool
tags map[string]string
@@ -152,6 +152,10 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
fi = new(fieldInfo)
+ // if field which CanAddr is the follow type
+ // A value is addressable if it is an element of a slice,
+ // an element of an addressable array, a field of an
+ // addressable struct, or the result of dereferencing a pointer.
addrField = field
if field.CanAddr() && field.Kind() != reflect.Ptr {
addrField = field.Addr()
@@ -162,7 +166,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
}
}
- parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags)
+ attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName))
if _, ok := attrs["-"]; ok {
return nil, errSkipField
@@ -188,7 +192,7 @@ checkType:
}
fieldType = f.FieldType()
if fieldType&IsRelField > 0 {
- err = fmt.Errorf("unsupport rel type custom field")
+ err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42")
goto end
}
default:
@@ -211,7 +215,7 @@ checkType:
}
break checkType
default:
- err = fmt.Errorf("error")
+ err = fmt.Errorf("rel only allow these value: fk, one, m2m")
goto wrongTag
}
}
@@ -231,7 +235,7 @@ checkType:
}
break checkType
default:
- err = fmt.Errorf("error")
+ err = fmt.Errorf("reverse only allow these value: one, many")
goto wrongTag
}
}
@@ -261,6 +265,9 @@ checkType:
}
}
+ // check the rel and reverse type
+ // rel should Ptr
+ // reverse should slice []*struct
switch fieldType {
case RelForeignKey, RelOneToOne, RelReverseOne:
if field.Kind() != reflect.Ptr {
@@ -399,14 +406,12 @@ checkType:
if fi.auto || fi.pk {
if fi.auto {
-
switch addrField.Elem().Kind() {
case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64:
default:
err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind())
goto end
}
-
fi.pk = true
}
fi.null = false
@@ -418,8 +423,8 @@ checkType:
fi.index = false
}
+ // can not set default for these type
if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField {
- // can not set default
initial.Clear()
}
diff --git a/orm/models_info_m.go b/orm/models_info_m.go
index bbb82444..d6ba1dca 100644
--- a/orm/models_info_m.go
+++ b/orm/models_info_m.go
@@ -29,31 +29,25 @@ type modelInfo struct {
model interface{}
fields *fields
manual bool
- addrField reflect.Value
+ addrField reflect.Value //store the original struct value
uniques []string
isThrough bool
}
// new model info
-func newModelInfo(val reflect.Value) (info *modelInfo) {
-
- info = &modelInfo{}
- info.fields = newFields()
-
+func newModelInfo(val reflect.Value) (mi *modelInfo) {
+ mi = &modelInfo{}
+ mi.fields = newFields()
ind := reflect.Indirect(val)
- typ := ind.Type()
-
- info.addrField = val
-
- info.name = typ.Name()
- info.fullName = getFullName(typ)
-
- addModelFields(info, ind, "", []int{})
-
+ mi.addrField = val
+ mi.name = ind.Type().Name()
+ mi.fullName = getFullName(ind.Type())
+ addModelFields(mi, ind, "", []int{})
return
}
-func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []int) {
+// index: FieldByIndex returns the nested field corresponding to index
+func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) {
var (
err error
fi *fieldInfo
@@ -63,43 +57,39 @@ func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []in
for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i)
sf = ind.Type().Field(i)
+ // if the field is unexported skip
if sf.PkgPath != "" {
continue
}
// add anonymous struct fields
if sf.Anonymous {
- addModelFields(info, field, mName+"."+sf.Name, append(index, i))
+ addModelFields(mi, field, mName+"."+sf.Name, append(index, i))
continue
}
- fi, err = newFieldInfo(info, field, sf, mName)
-
- if err != nil {
- if err == errSkipField {
- err = nil
- continue
- }
+ fi, err = newFieldInfo(mi, field, sf, mName)
+ if err == errSkipField {
+ err = nil
+ continue
+ } else if err != nil {
break
}
-
- added := info.fields.Add(fi)
- if added == false {
+ //record current field index
+ fi.fieldIndex = append(index, i)
+ fi.mi = mi
+ fi.inModel = true
+ if mi.fields.Add(fi) == false {
err = fmt.Errorf("duplicate column name: %s", fi.column)
break
}
-
if fi.pk {
- if info.fields.pk != nil {
+ if mi.fields.pk != nil {
err = fmt.Errorf("one model must have one pk field only")
break
} else {
- info.fields.pk = fi
+ mi.fields.pk = fi
}
}
-
- fi.fieldIndex = append(index, i)
- fi.mi = info
- fi.inModel = true
}
if err != nil {
@@ -110,23 +100,23 @@ func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []in
// combine related model info to new model info.
// prepare for relation models query.
-func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
- info = new(modelInfo)
- info.fields = newFields()
- info.table = m1.table + "_" + m2.table + "s"
- info.name = camelString(info.table)
- info.fullName = m1.pkg + "." + info.name
+func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) {
+ mi = new(modelInfo)
+ mi.fields = newFields()
+ mi.table = m1.table + "_" + m2.table + "s"
+ mi.name = camelString(mi.table)
+ mi.fullName = m1.pkg + "." + mi.name
- fa := new(fieldInfo)
- f1 := new(fieldInfo)
- f2 := new(fieldInfo)
+ fa := new(fieldInfo) // pk
+ f1 := new(fieldInfo) // m1 table RelForeignKey
+ f2 := new(fieldInfo) // m2 table RelForeignKey
fa.fieldType = TypeBigIntegerField
fa.auto = true
fa.pk = true
fa.dbcol = true
fa.name = "Id"
fa.column = "id"
- fa.fullName = info.fullName + "." + fa.name
+ fa.fullName = mi.fullName + "." + fa.name
f1.dbcol = true
f2.dbcol = true
@@ -134,8 +124,8 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
f2.fieldType = RelForeignKey
f1.name = camelString(m1.table)
f2.name = camelString(m2.table)
- f1.fullName = info.fullName + "." + f1.name
- f2.fullName = info.fullName + "." + f2.name
+ f1.fullName = mi.fullName + "." + f1.name
+ f2.fullName = mi.fullName + "." + f2.name
f1.column = m1.table + "_id"
f2.column = m2.table + "_id"
f1.rel = true
@@ -144,14 +134,14 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
f2.relTable = m2.table
f1.relModelInfo = m1
f2.relModelInfo = m2
- f1.mi = info
- f2.mi = info
+ f1.mi = mi
+ f2.mi = mi
- info.fields.Add(fa)
- info.fields.Add(f1)
- info.fields.Add(f2)
- info.fields.pk = fa
+ mi.fields.Add(fa)
+ mi.fields.Add(f1)
+ mi.fields.Add(f2)
+ mi.fields.pk = fa
- info.uniques = []string{f1.column, f2.column}
+ mi.uniques = []string{f1.column, f2.column}
return
}
diff --git a/orm/models_test.go b/orm/models_test.go
index c68c7339..462370b2 100644
--- a/orm/models_test.go
+++ b/orm/models_test.go
@@ -181,6 +181,9 @@ type DataNull struct {
Float32Ptr *float32 `orm:"null"`
Float64Ptr *float64 `orm:"null"`
DecimalPtr *float64 `orm:"digits(8);decimals(4);null"`
+ TimePtr *time.Time `orm:"null;type(time)"`
+ DatePtr *time.Time `orm:"null;type(date)"`
+ DateTimePtr *time.Time `orm:"null"`
}
type String string
diff --git a/orm/models_utils.go b/orm/models_utils.go
index ec11d516..44a0e76a 100644
--- a/orm/models_utils.go
+++ b/orm/models_utils.go
@@ -22,25 +22,47 @@ import (
"time"
)
+// 1 is attr
+// 2 is tag
+var supportTag = map[string]int{
+ "-": 1,
+ "null": 1,
+ "index": 1,
+ "unique": 1,
+ "pk": 1,
+ "auto": 1,
+ "auto_now": 1,
+ "auto_now_add": 1,
+ "size": 2,
+ "column": 2,
+ "default": 2,
+ "rel": 2,
+ "reverse": 2,
+ "rel_table": 2,
+ "rel_through": 2,
+ "digits": 2,
+ "decimals": 2,
+ "on_delete": 2,
+ "type": 2,
+}
+
// get reflect.Type name with package path.
func getFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name()
}
-// get table name. method, or field name. auto snaked.
+// getTableName get struct table name.
+// If the struct implement the TableName, then get the result as tablename
+// else use the struct name which will apply snakeString.
func getTableName(val reflect.Value) string {
- ind := reflect.Indirect(val)
- fun := val.MethodByName("TableName")
- if fun.IsValid() {
+ if fun := val.MethodByName("TableName"); fun.IsValid() {
vals := fun.Call([]reflect.Value{})
- if len(vals) > 0 {
- val := vals[0]
- if val.Kind() == reflect.String {
- return val.String()
- }
+ // has return and the first val is string
+ if len(vals) > 0 && vals[0].Kind() == reflect.String {
+ return vals[0].String()
}
}
- return snakeString(ind.Type().Name())
+ return snakeString(reflect.Indirect(val).Type().Name())
}
// get table engine, mysiam or innodb.
@@ -48,11 +70,8 @@ func getTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
- if len(vals) > 0 {
- val := vals[0]
- if val.Kind() == reflect.String {
- return val.String()
- }
+ if len(vals) > 0 && vals[0].Kind() == reflect.String {
+ return vals[0].String()
}
}
return ""
@@ -63,12 +82,9 @@ func getTableIndex(val reflect.Value) [][]string {
fun := val.MethodByName("TableIndex")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
- if len(vals) > 0 {
- val := vals[0]
- if val.CanInterface() {
- if d, ok := val.Interface().([][]string); ok {
- return d
- }
+ if len(vals) > 0 && vals[0].CanInterface() {
+ if d, ok := vals[0].Interface().([][]string); ok {
+ return d
}
}
}
@@ -80,12 +96,9 @@ func getTableUnique(val reflect.Value) [][]string {
fun := val.MethodByName("TableUnique")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
- if len(vals) > 0 {
- val := vals[0]
- if val.CanInterface() {
- if d, ok := val.Interface().([][]string); ok {
- return d
- }
+ if len(vals) > 0 && vals[0].CanInterface() {
+ if d, ok := vals[0].Interface().([][]string); ok {
+ return d
}
}
}
@@ -137,6 +150,8 @@ func getFieldType(val reflect.Value) (ft int, err error) {
ft = TypeBooleanField
case reflect.TypeOf(new(string)):
ft = TypeCharField
+ case reflect.TypeOf(new(time.Time)):
+ ft = TypeDateTimeField
default:
elm := reflect.Indirect(val)
switch elm.Kind() {
@@ -187,21 +202,25 @@ func getFieldType(val reflect.Value) (ft int, err error) {
}
// parse struct tag string
-func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
- attr := make(map[string]bool)
- tag := make(map[string]string)
+func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) {
+ attrs = make(map[string]bool)
+ tags = make(map[string]string)
for _, v := range strings.Split(data, defaultStructTagDelim) {
+ if v == "" {
+ continue
+ }
v = strings.TrimSpace(v)
- if supportTag[v] == 1 {
- attr[v] = true
+ if t := strings.ToLower(v); supportTag[t] == 1 {
+ attrs[t] = true
} else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 {
- name := v[:i]
+ name := t[:i]
if supportTag[name] == 2 {
v = v[i+1 : len(v)-1]
- tag[name] = v
+ tags[name] = v
}
+ } else {
+ DebugLog.Println("unsupport orm tag", v)
}
}
- *attrs = attr
- *tags = tag
+ return
}
diff --git a/orm/orm.go b/orm/orm.go
index 5e43ae59..538916e4 100644
--- a/orm/orm.go
+++ b/orm/orm.go
@@ -68,7 +68,7 @@ const (
// Define common vars
var (
Debug = false
- DebugLog = NewLog(os.Stderr)
+ DebugLog = NewLog(os.Stdout)
DefaultRowsLimit = 1000
DefaultRelsDepth = 2
DefaultTimeLoc = time.Local
@@ -104,7 +104,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect
panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ)))
}
name := getFullName(typ)
- if mi, ok := modelCache.getByFN(name); ok {
+ if mi, ok := modelCache.getByFullName(name); ok {
return mi, ind
}
panic(fmt.Errorf(" table: `%s` not found, maybe not RegisterModel", name))
@@ -122,7 +122,17 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
// read data to model
func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true)
- err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
+ err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// read data to model, like Read(), but use "SELECT FOR UPDATE" form
+func (o *orm) ReadForUpdate(md interface{}, cols ...string) error {
+ mi, ind := o.getMiInd(md, true)
+ err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
if err != nil {
return err
}
@@ -133,7 +143,7 @@ func (o *orm) Read(md interface{}, cols ...string) error {
func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
cols = append([]string{col1}, cols...)
mi, ind := o.getMiInd(md, true)
- err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
+ err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
if err == ErrNoRows {
// Create
id, err := o.Insert(md)
@@ -209,6 +219,19 @@ func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
return cnt, nil
}
+// InsertOrUpdate data to database
+func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
+ mi, ind := o.getMiInd(md, true)
+ id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
+ if err != nil {
+ return id, err
+ }
+
+ o.setPk(mi, ind, id)
+
+ return id, nil
+}
+
// update model to database.
// cols set the columns those want to update.
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
@@ -221,9 +244,10 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
}
// delete model in database
-func (o *orm) Delete(md interface{}) (int64, error) {
+// cols shows the delete conditions values read from. deafult is pk
+func (o *orm) Delete(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
- num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
+ num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return num, err
}
@@ -414,7 +438,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
}
} else {
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
- if mi, ok := modelCache.getByFN(name); ok {
+ if mi, ok := modelCache.getByFullName(name); ok {
qs = newQuerySet(o, mi)
}
}
diff --git a/orm/orm_log.go b/orm/orm_log.go
index 54723273..26c73f9e 100644
--- a/orm/orm_log.go
+++ b/orm/orm_log.go
@@ -42,7 +42,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
if err != nil {
flag = "FAIL"
}
- con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(formatDateTime), alias.Name, flag, operaton, elsp, query)
+ con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query)
cons := make([]string, 0, len(args))
for _, arg := range args {
cons = append(cons, fmt.Sprintf("%v", arg))
diff --git a/orm/orm_raw.go b/orm/orm_raw.go
index 5f88121c..a968b1a1 100644
--- a/orm/orm_raw.go
+++ b/orm/orm_raw.go
@@ -286,7 +286,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
structMode = true
fn := getFullName(typ)
- if mi, ok := modelCache.getByFN(fn); ok {
+ if mi, ok := modelCache.getByFullName(fn); ok {
sMi = mi
}
} else {
@@ -342,19 +342,22 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
- o.setFieldValue(ind.FieldByIndex(fi.fieldIndex), value)
+ field := ind.FieldByIndex(fi.fieldIndex)
+ if fi.fieldType&IsRelField > 0 {
+ mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
+ field.Set(mf)
+ field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex)
+ }
+ o.setFieldValue(field, value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
-
- var attrs map[string]bool
- var tags map[string]string
- parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
+ _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string
- if col = tags["column"]; len(col) == 0 {
+ if col = tags["column"]; col == "" {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
@@ -416,7 +419,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
structMode = true
fn := getFullName(typ)
- if mi, ok := modelCache.getByFN(fn); ok {
+ if mi, ok := modelCache.getByFullName(fn); ok {
sMi = mi
}
} else {
@@ -480,19 +483,22 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
- o.setFieldValue(ind.FieldByIndex(fi.fieldIndex), value)
+ field := ind.FieldByIndex(fi.fieldIndex)
+ if fi.fieldType&IsRelField > 0 {
+ mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
+ field.Set(mf)
+ field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex)
+ }
+ o.setFieldValue(field, value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
-
- var attrs map[string]bool
- var tags map[string]string
- parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
+ _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string
- if col = tags["column"]; len(col) == 0 {
+ if col = tags["column"]; col == "" {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
diff --git a/orm/orm_test.go b/orm/orm_test.go
index 11f6bd56..fbf4768d 100644
--- a/orm/orm_test.go
+++ b/orm/orm_test.go
@@ -227,7 +227,7 @@ func TestModelSyntax(t *testing.T) {
user := &User{}
ind := reflect.ValueOf(user).Elem()
fn := getFullName(ind.Type())
- mi, ok := modelCache.getByFN(fn)
+ mi, ok := modelCache.getByFullName(fn)
throwFail(t, AssertIs(ok, true))
mi, ok = modelCache.get("user")
@@ -350,6 +350,9 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(d.Float32Ptr, nil))
throwFail(t, AssertIs(d.Float64Ptr, nil))
throwFail(t, AssertIs(d.DecimalPtr, nil))
+ throwFail(t, AssertIs(d.TimePtr, nil))
+ throwFail(t, AssertIs(d.DatePtr, nil))
+ throwFail(t, AssertIs(d.DateTimePtr, nil))
_, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
throwFail(t, err)
@@ -376,6 +379,9 @@ func TestNullDataTypes(t *testing.T) {
float32Ptr := float32(42.0)
float64Ptr := float64(42.0)
decimalPtr := float64(42.0)
+ timePtr := time.Now()
+ datePtr := time.Now()
+ dateTimePtr := time.Now()
d = DataNull{
DateTime: time.Now(),
@@ -401,6 +407,9 @@ func TestNullDataTypes(t *testing.T) {
Float32Ptr: &float32Ptr,
Float64Ptr: &float64Ptr,
DecimalPtr: &decimalPtr,
+ TimePtr: &timePtr,
+ DatePtr: &datePtr,
+ DateTimePtr: &dateTimePtr,
}
id, err = dORM.Insert(&d)
@@ -441,6 +450,9 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr))
throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr))
throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr))
+ throwFail(t, AssertIs((*d.TimePtr).Format(testTime), timePtr.Format(testTime)))
+ throwFail(t, AssertIs((*d.DatePtr).Format(testDate), datePtr.Format(testDate)))
+ throwFail(t, AssertIs((*d.DateTimePtr).Format(testDateTime), dateTimePtr.Format(testDateTime)))
}
func TestDataCustomTypes(t *testing.T) {
@@ -565,6 +577,10 @@ func TestCRUD(t *testing.T) {
err = dORM.Read(&ub)
throwFail(t, err)
throwFail(t, AssertIs(ub.Name, "name"))
+
+ num, err = dORM.Delete(&ub, "name")
+ throwFail(t, err)
+ throwFail(t, AssertIs(num, 1))
}
func TestInsertTestData(t *testing.T) {
@@ -2050,7 +2066,7 @@ func TestIntegerPk(t *testing.T) {
throwFail(t, AssertIs(out.Value, intPk.Value))
}
- num, err = dORM.InsertMulti(1, []*IntegerPk{&IntegerPk{
+ num, err = dORM.InsertMulti(1, []*IntegerPk{{
ID: 1, Value: "ok",
}})
throwFail(t, err)
@@ -2117,3 +2133,134 @@ func TestUintPk(t *testing.T) {
dORM.Delete(u)
}
+
+func TestSnake(t *testing.T) {
+ cases := map[string]string{
+ "i": "i",
+ "I": "i",
+ "iD": "i_d",
+ "ID": "i_d",
+ "NO": "n_o",
+ "NOO": "n_o_o",
+ "NOOooOOoo": "n_o_ooo_o_ooo",
+ "OrderNO": "order_n_o",
+ "tagName": "tag_name",
+ "tag_Name": "tag__name",
+ "tag_name": "tag_name",
+ "_tag_name": "_tag_name",
+ "tag_666name": "tag_666name",
+ "tag_666Name": "tag_666_name",
+ }
+ for name, want := range cases {
+ got := snakeString(name)
+ throwFail(t, AssertIs(got, want))
+ }
+}
+
+func TestIgnoreCaseTag(t *testing.T) {
+ type testTagModel struct {
+ ID int `orm:"pk"`
+ NOO string `orm:"column(n)"`
+ Name01 string `orm:"NULL"`
+ Name02 string `orm:"COLUMN(Name)"`
+ Name03 string `orm:"Column(name)"`
+ }
+ modelCache.clean()
+ RegisterModel(&testTagModel{})
+ info, ok := modelCache.get("test_tag_model")
+ throwFail(t, AssertIs(ok, true))
+ throwFail(t, AssertNot(info, nil))
+ if t == nil {
+ return
+ }
+ throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n"))
+ throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true))
+ throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name"))
+ throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name"))
+}
+func TestInsertOrUpdate(t *testing.T) {
+ RegisterModel(new(User))
+ user := User{UserName: "unique_username133", Status: 1, Password: "o"}
+ user1 := User{UserName: "unique_username133", Status: 2, Password: "o"}
+ user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"}
+ dORM.Insert(&user)
+ test := User{UserName: "unique_username133"}
+ fmt.Println(dORM.Driver().Name())
+ if dORM.Driver().Name() == "sqlite3" {
+ fmt.Println("sqlite3 is nonsupport")
+ return
+ }
+ //test1
+ _, err := dORM.InsertOrUpdate(&user1, "user_name")
+ if err != nil {
+ fmt.Println(err)
+ if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
+ } else {
+ throwFailNow(t, err)
+ }
+ } else {
+ dORM.Read(&test, "user_name")
+ throwFailNow(t, AssertIs(user1.Status, test.Status))
+ }
+ //test2
+ _, err = dORM.InsertOrUpdate(&user2, "user_name")
+ if err != nil {
+ fmt.Println(err)
+ if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
+ } else {
+ throwFailNow(t, err)
+ }
+ } else {
+ dORM.Read(&test, "user_name")
+ throwFailNow(t, AssertIs(user2.Status, test.Status))
+ throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password)))
+ }
+ //test3 +
+ _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1")
+ if err != nil {
+ fmt.Println(err)
+ if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
+ } else {
+ throwFailNow(t, err)
+ }
+ } else {
+ dORM.Read(&test, "user_name")
+ throwFailNow(t, AssertIs(user2.Status+1, test.Status))
+ }
+ //test4 -
+ _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1")
+ if err != nil {
+ fmt.Println(err)
+ if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
+ } else {
+ throwFailNow(t, err)
+ }
+ } else {
+ dORM.Read(&test, "user_name")
+ throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status))
+ }
+ //test5 *
+ _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3")
+ if err != nil {
+ fmt.Println(err)
+ if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
+ } else {
+ throwFailNow(t, err)
+ }
+ } else {
+ dORM.Read(&test, "user_name")
+ throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status))
+ }
+ //test6 /
+ _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3")
+ if err != nil {
+ fmt.Println(err)
+ if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" {
+ } else {
+ throwFailNow(t, err)
+ }
+ } else {
+ dORM.Read(&test, "user_name")
+ throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status))
+ }
+}
diff --git a/orm/qb.go b/orm/qb.go
index 9f778916..e0655a17 100644
--- a/orm/qb.go
+++ b/orm/qb.go
@@ -19,6 +19,7 @@ import "errors"
// QueryBuilder is the Query builder interface
type QueryBuilder interface {
Select(fields ...string) QueryBuilder
+ ForUpdate() QueryBuilder
From(tables ...string) QueryBuilder
InnerJoin(table string) QueryBuilder
LeftJoin(table string) QueryBuilder
diff --git a/orm/qb_mysql.go b/orm/qb_mysql.go
index 886bc50e..23bdc9ee 100644
--- a/orm/qb_mysql.go
+++ b/orm/qb_mysql.go
@@ -34,6 +34,12 @@ func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder {
return qb
}
+// ForUpdate add the FOR UPDATE clause
+func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "FOR UPDATE")
+ return qb
+}
+
// From join the tables
func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
diff --git a/orm/qb_tidb.go b/orm/qb_tidb.go
index c504049e..87b3ae84 100644
--- a/orm/qb_tidb.go
+++ b/orm/qb_tidb.go
@@ -31,6 +31,12 @@ func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder {
return qb
}
+// ForUpdate add the FOR UPDATE clause
+func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder {
+ qb.Tokens = append(qb.Tokens, "FOR UPDATE")
+ return qb
+}
+
// From join the tables
func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder {
qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace))
diff --git a/orm/types.go b/orm/types.go
index cb55e71a..fd3062ab 100644
--- a/orm/types.go
+++ b/orm/types.go
@@ -45,6 +45,9 @@ type Ormer interface {
// u = &User{UserName: "astaxie", Password: "pass"}
// err = Ormer.Read(u, "UserName")
Read(md interface{}, cols ...string) error
+ // Like Read(), but with "FOR UPDATE" clause, useful in transaction.
+ // Some databases are not support this feature.
+ ReadForUpdate(md interface{}, cols ...string) error
// Try to read a row from the database, or insert one if it doesn't exist
ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error)
// insert model data to database
@@ -53,6 +56,11 @@ type Ormer interface {
// id, err = Ormer.Insert(user)
// user must a pointer and Insert will set user's pk field
Insert(interface{}) (int64, error)
+ // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value")
+ // if colu type is integer : can use(+-*/), string : convert(colu,"value")
+ // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value")
+ // if colu type is integer : can use(+-*/), string : colu || "value"
+ InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error)
// insert some models to database
InsertMulti(bulk int, mds interface{}) (int64, error)
// update model to database.
@@ -66,7 +74,7 @@ type Ormer interface {
// num, err = Ormer.Update(&user, "Langs", "Extra")
Update(md interface{}, cols ...string) (int64, error)
// delete model in database
- Delete(md interface{}) (int64, error)
+ Delete(md interface{}, cols ...string) (int64, error)
// load related models to md model.
// args are limit, offset int and order string.
//
@@ -389,13 +397,14 @@ type txEnder interface {
// base database struct
type dbBaser interface {
- Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
+ Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
+ InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error)
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
- Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
+ Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
SupportUpdateJoin() bool
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
diff --git a/orm/utils.go b/orm/utils.go
index 99437c7b..6e23447e 100644
--- a/orm/utils.go
+++ b/orm/utils.go
@@ -181,7 +181,7 @@ func ToInt64(value interface{}) (d int64) {
return
}
-// snake string, XxYy to xx_yy
+// snake string, XxYy to xx_yy , XxYY to xx_yy
func snakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
diff --git a/parser.go b/parser.go
index 3bf3cf6b..ffcd27a4 100644
--- a/parser.go
+++ b/parser.go
@@ -101,7 +101,7 @@ func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpat
elements := strings.TrimLeft(t, "@router ")
e1 := strings.SplitN(elements, " ", 2)
if len(e1) < 1 {
- return errors.New("you should has router infomation")
+ return errors.New("you should has router information")
}
key := pkgpath + ":" + controllerName
cc := ControllerComments{}
@@ -166,10 +166,10 @@ func genRouterCode(pkgRealpath string) {
globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{
- "` + strings.TrimSpace(c.Method) + `",
- ` + "`" + c.Router + "`" + `,
- ` + allmethod + `,
- ` + params + `})
+ Method: "` + strings.TrimSpace(c.Method) + `",
+ ` + "Router: `" + c.Router + "`" + `,
+ AllowHTTPMethods: ` + allmethod + `,
+ Params: ` + params + `})
`
}
}
diff --git a/router.go b/router.go
index 1e35895b..97b0edba 100644
--- a/router.go
+++ b/router.go
@@ -406,20 +406,27 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
}
// InsertFilter Add a FilterFunc with pattern rule and action constant.
-// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute)
+// params is for:
+// 1. setting the returnOnOutput value (false allows multiple filters to execute)
+// 2. determining whether or not params need to be reset.
func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error {
- mr := new(FilterRouter)
- mr.tree = NewTree()
- mr.pattern = pattern
- mr.filterFunc = filter
- if !BConfig.RouterCaseSensitive {
- pattern = strings.ToLower(pattern)
+ mr := &FilterRouter{
+ tree: NewTree(),
+ pattern: pattern,
+ filterFunc: filter,
+ returnOnOutput: true,
}
- if len(params) == 0 {
- mr.returnOnOutput = true
- } else {
+ if !BConfig.RouterCaseSensitive {
+ mr.pattern = strings.ToLower(pattern)
+ }
+
+ paramsLen := len(params)
+ if paramsLen > 0 {
mr.returnOnOutput = params[0]
}
+ if paramsLen > 1 {
+ mr.resetParams = params[1]
+ }
mr.tree.AddRouter(pattern, true)
return p.insertFilterRouter(pos, mr)
}
@@ -427,7 +434,7 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
// add Filter into
func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) {
if pos < BeforeStatic || pos > FinishRouter {
- err = fmt.Errorf("can not find your filter postion")
+ err = fmt.Errorf("can not find your filter position")
return
}
p.enableFilter = true
@@ -581,12 +588,22 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
}
func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) {
+ var preFilterParams map[string]string
for _, filterR := range p.filters[pos] {
if filterR.returnOnOutput && context.ResponseWriter.Started {
return true
}
+ if filterR.resetParams {
+ preFilterParams = context.Input.Params()
+ }
if ok := filterR.ValidRouter(urlPath, context); ok {
filterR.filterFunc(context)
+ if filterR.resetParams {
+ context.Input.ResetParams()
+ for k, v := range preFilterParams {
+ context.Input.SetParam(k, v)
+ }
+ }
}
if filterR.returnOnOutput && context.ResponseWriter.Started {
return true
@@ -609,7 +626,9 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
context.Reset(rw, r)
defer p.pool.Put(context)
- defer p.recoverPanic(context)
+ if BConfig.RecoverFunc != nil {
+ defer BConfig.RecoverFunc(context)
+ }
context.Output.EnableGzip = BConfig.EnableGzip
@@ -666,8 +685,16 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) {
goto Admin
}
+ // User can define RunController and RunMethod in filter
+ if context.Input.RunController != nil && context.Input.RunMethod != "" {
+ findRouter = true
+ isRunnable = true
+ runMethod = context.Input.RunMethod
+ runRouter = context.Input.RunController
+ } else {
+ routerInfo, findRouter = p.FindRouter(context)
+ }
- routerInfo, findRouter = p.FindRouter(context)
//if no matches to url, throw a not found exception
if !findRouter {
exception("404", context)
@@ -679,15 +706,16 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
}
- //store router pattern into context
- context.Input.SetData("RouterPattern", routerInfo.pattern)
-
//execute middleware filters
if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) {
goto Admin
}
if routerInfo != nil {
+ if BConfig.RunMode == DEV {
+ //store router pattern into context
+ context.Input.SetData("RouterPattern", routerInfo.pattern)
+ }
if routerInfo.routerType == routerTypeRESTFul {
if _, ok := routerInfo.methods[r.Method]; ok {
isRunnable = true
@@ -808,16 +836,33 @@ Admin:
if BConfig.RunMode == DEV || BConfig.Log.AccessLogs {
timeDur := time.Since(startTime)
var devInfo string
+
+ statusCode := context.ResponseWriter.Status
+ if statusCode == 0 {
+ statusCode = 200
+ }
+
+ iswin := (runtime.GOOS == "windows")
+ statusColor := logs.ColorByStatus(iswin, statusCode)
+ methodColor := logs.ColorByMethod(iswin, r.Method)
+ resetColor := logs.ColorByMethod(iswin, "")
+
if findRouter {
if routerInfo != nil {
- devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s | % -40s |", r.Method, r.URL.Path, timeDur.String(), "match", routerInfo.pattern)
+ devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s r:%s", context.Input.IP(), statusColor, statusCode,
+ resetColor, timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path,
+ routerInfo.pattern)
} else {
- devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "match")
+ devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
+ timeDur.String(), "match", methodColor, r.Method, resetColor, r.URL.Path)
}
} else {
- devInfo = fmt.Sprintf("| % -10s | % -40s | % -16s | % -10s |", r.Method, r.URL.Path, timeDur.String(), "notmatch")
+ devInfo = fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", context.Input.IP(), statusColor, statusCode, resetColor,
+ timeDur.String(), "nomatch", methodColor, r.Method, resetColor, r.URL.Path)
}
- if DefaultAccessLogFilter == nil || !DefaultAccessLogFilter.Filter(context) {
+ if iswin {
+ logs.W32Debug(devInfo)
+ } else {
logs.Debug(devInfo)
}
}
@@ -844,37 +889,6 @@ func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo
return
}
-func (p *ControllerRegister) recoverPanic(context *beecontext.Context) {
- if err := recover(); err != nil {
- if err == ErrAbort {
- return
- }
- if !BConfig.RecoverPanic {
- panic(err)
- }
- if BConfig.EnableErrorsShow {
- if _, ok := ErrorMaps[fmt.Sprint(err)]; ok {
- exception(fmt.Sprint(err), context)
- return
- }
- }
- var stack string
- logs.Critical("the request url is ", context.Input.URL())
- logs.Critical("Handler crashed with error", err)
- for i := 1; ; i++ {
- _, file, line, ok := runtime.Caller(i)
- if !ok {
- break
- }
- logs.Critical(fmt.Sprintf("%s:%d", file, line))
- stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
- }
- if BConfig.RunMode == DEV {
- showErr(err, context, stack)
- }
- }
-}
-
func toURL(params map[string]string) string {
if len(params) == 0 {
return ""
diff --git a/router_test.go b/router_test.go
index 9f11286c..936fd5e8 100644
--- a/router_test.go
+++ b/router_test.go
@@ -420,6 +420,74 @@ func testRequest(method, path string) (*httptest.ResponseRecorder, *http.Request
return recorder, request
}
+// Expectation: A Filter with the correct configuration should be created given
+// specific parameters.
+func TestInsertFilter(t *testing.T) {
+ testName := "TestInsertFilter"
+
+ mux := NewControllerRegister()
+ mux.InsertFilter("*", BeforeRouter, func(*context.Context) {})
+ if !mux.filters[BeforeRouter][0].returnOnOutput {
+ t.Errorf(
+ "%s: passing no variadic params should set returnOnOutput to true",
+ testName)
+ }
+ if mux.filters[BeforeRouter][0].resetParams {
+ t.Errorf(
+ "%s: passing no variadic params should set resetParams to false",
+ testName)
+ }
+
+ mux = NewControllerRegister()
+ mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false)
+ if mux.filters[BeforeRouter][0].returnOnOutput {
+ t.Errorf(
+ "%s: passing false as 1st variadic param should set returnOnOutput to false",
+ testName)
+ }
+
+ mux = NewControllerRegister()
+ mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true)
+ if !mux.filters[BeforeRouter][0].resetParams {
+ t.Errorf(
+ "%s: passing true as 2nd variadic param should set resetParams to true",
+ testName)
+ }
+}
+
+// Expectation: the second variadic arg should cause the execution of the filter
+// to preserve the parameters from before its execution.
+func TestParamResetFilter(t *testing.T) {
+ testName := "TestParamResetFilter"
+ route := "/beego/*" // splat
+ path := "/beego/routes/routes"
+
+ mux := NewControllerRegister()
+
+ mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true)
+
+ mux.Get(route, beegoHandleResetParams)
+
+ rw, r := testRequest("GET", path)
+ mux.ServeHTTP(rw, r)
+
+ // The two functions, `beegoResetParams` and `beegoHandleResetParams` add
+ // a response header of `Splat`. The expectation here is that that Header
+ // value should match what the _request's_ router set, not the filter's.
+
+ headers := rw.HeaderMap
+ if len(headers["Splat"]) != 1 {
+ t.Errorf(
+ "%s: There was an error in the test. Splat param not set in Header",
+ testName)
+ }
+ if headers["Splat"][0] != "routes/routes" {
+ t.Errorf(
+ "%s: expected `:splat` param to be [routes/routes] but it was [%s]",
+ testName, headers["Splat"][0])
+ }
+}
+
// Execution point: BeforeRouter
// expectation: only BeforeRouter function is executed, notmatch output as router doesn't handle
func TestFilterBeforeRouter(t *testing.T) {
@@ -612,3 +680,10 @@ func beegoFinishRouter1(ctx *context.Context) {
func beegoFinishRouter2(ctx *context.Context) {
ctx.WriteString("|FinishRouter2")
}
+func beegoResetParams(ctx *context.Context) {
+ ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat"))
+}
+
+func beegoHandleResetParams(ctx *context.Context) {
+ ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat"))
+}
diff --git a/session/sess_cookie_test.go b/session/sess_cookie_test.go
index 209e501c..b6726005 100644
--- a/session/sess_cookie_test.go
+++ b/session/sess_cookie_test.go
@@ -15,6 +15,7 @@
package session
import (
+ "encoding/json"
"net/http"
"net/http/httptest"
"strings"
@@ -23,7 +24,11 @@ import (
func TestCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
- globalSessions, err := NewManager("cookie", config)
+ conf := new(ManagerConfig)
+ if err := json.Unmarshal([]byte(config), conf); err != nil {
+ t.Fatal("json decode error", err)
+ }
+ globalSessions, err := NewManager("cookie", conf)
if err != nil {
t.Fatal("init cookie session err", err)
}
@@ -56,7 +61,11 @@ func TestCookie(t *testing.T) {
func TestDestorySessionCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
- globalSessions, err := NewManager("cookie", config)
+ conf := new(ManagerConfig)
+ if err := json.Unmarshal([]byte(config), conf); err != nil {
+ t.Fatal("json decode error", err)
+ }
+ globalSessions, err := NewManager("cookie", conf)
if err != nil {
t.Fatal("init cookie session err", err)
}
diff --git a/session/sess_file.go b/session/sess_file.go
index 91acfcd4..132f5a00 100644
--- a/session/sess_file.go
+++ b/session/sess_file.go
@@ -88,10 +88,9 @@ func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777)
- SLogger.Println(err)
} else if os.IsNotExist(err) {
f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
- SLogger.Println(err)
+
} else {
return
}
diff --git a/session/sess_mem_test.go b/session/sess_mem_test.go
index 43f5b0a9..2e8934b8 100644
--- a/session/sess_mem_test.go
+++ b/session/sess_mem_test.go
@@ -15,6 +15,7 @@
package session
import (
+ "encoding/json"
"net/http"
"net/http/httptest"
"strings"
@@ -22,7 +23,12 @@ import (
)
func TestMem(t *testing.T) {
- globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`)
+ config := `{"cookieName":"gosessionid","gclifetime":10, "enableSetCookie":true}`
+ conf := new(ManagerConfig)
+ if err := json.Unmarshal([]byte(config), conf); err != nil {
+ t.Fatal("json decode error", err)
+ }
+ globalSessions, _ := NewManager("memory", conf)
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
diff --git a/session/sess_test.go b/session/sess_test.go
index 5ba910f2..b40865f3 100644
--- a/session/sess_test.go
+++ b/session/sess_test.go
@@ -89,7 +89,7 @@ func TestCookieEncodeDecode(t *testing.T) {
func TestParseConfig(t *testing.T) {
s := `{"cookieName":"gosessionid","gclifetime":3600}`
- cf := new(managerConfig)
+ cf := new(ManagerConfig)
cf.EnableSetCookie = true
err := json.Unmarshal([]byte(s), cf)
if err != nil {
@@ -103,7 +103,7 @@ func TestParseConfig(t *testing.T) {
}
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
- cf2 := new(managerConfig)
+ cf2 := new(ManagerConfig)
cf2.EnableSetCookie = true
err = json.Unmarshal([]byte(cc), cf2)
if err != nil {
diff --git a/session/session.go b/session/session.go
index 73f0d677..3c9d07ab 100644
--- a/session/session.go
+++ b/session/session.go
@@ -30,7 +30,6 @@ package session
import (
"crypto/rand"
"encoding/hex"
- "encoding/json"
"errors"
"fmt"
"io"
@@ -82,7 +81,7 @@ func Register(name string, provide Provider) {
provides[name] = provide
}
-type managerConfig struct {
+type ManagerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
Gclifetime int64 `json:"gclifetime"`
@@ -100,7 +99,7 @@ type managerConfig struct {
// Manager contains Provider and its configuration.
type Manager struct {
provider Provider
- config *managerConfig
+ config *ManagerConfig
}
// NewManager Create new Manager with provider name and json config string.
@@ -115,17 +114,12 @@ type Manager struct {
// 2. hashfunc default sha1
// 3. hashkey default beegosessionkey
// 4. maxage default is none
-func NewManager(provideName, config string) (*Manager, error) {
+func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) {
provider, ok := provides[provideName]
if !ok {
return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
}
- cf := new(managerConfig)
- cf.EnableSetCookie = true
- err := json.Unmarshal([]byte(config), cf)
- if err != nil {
- return nil, err
- }
+
if cf.Maxlifetime == 0 {
cf.Maxlifetime = cf.Gclifetime
}
@@ -142,7 +136,7 @@ func NewManager(provideName, config string) (*Manager, error) {
}
}
- err = provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
+ err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig)
if err != nil {
return nil, err
}
@@ -166,7 +160,7 @@ func NewManager(provideName, config string) (*Manager, error) {
// otherwise return an valid session id.
func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName)
- if errs != nil || cookie.Value == "" || cookie.MaxAge < 0 {
+ if errs != nil || cookie.Value == "" {
var sid string
if manager.config.EnableSidInUrlQuery {
errs := r.ParseForm()
@@ -211,6 +205,9 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
}
session, err = manager.provider.SessionRead(sid)
+ if err != nil {
+ return nil, errs
+ }
cookie := &http.Cookie{
Name: manager.config.CookieName,
Value: url.QueryEscape(sid),
diff --git a/staticfile.go b/staticfile.go
index 8a1bc57b..b7be24f3 100644
--- a/staticfile.go
+++ b/staticfile.go
@@ -57,7 +57,11 @@ func serverStaticRouter(ctx *context.Context) {
if fileInfo.IsDir() {
requestURL := ctx.Input.URL()
if requestURL[len(requestURL)-1] != '/' {
- ctx.Redirect(302, requestURL+"/")
+ redirectURL := requestURL + "/"
+ if ctx.Request.URL.RawQuery != "" {
+ redirectURL = redirectURL + "?" + ctx.Request.URL.RawQuery
+ }
+ ctx.Redirect(302, redirectURL)
} else {
//serveFile will list dir
http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath)
@@ -163,13 +167,10 @@ func searchFile(ctx *context.Context) (string, os.FileInfo, error) {
return filePath, fi, nil
}
}
- return "", nil, errors.New(requestPath + " file not find")
+ return "", nil, errNotStaticRequest
}
for prefix, staticDir := range BConfig.WebConfig.StaticDir {
- if len(prefix) == 0 {
- continue
- }
if !strings.Contains(requestPath, prefix) {
continue
}
@@ -195,9 +196,11 @@ func lookupFile(ctx *context.Context) (bool, string, os.FileInfo, error) {
if !fi.IsDir() {
return false, fp, fi, err
}
- ifp := filepath.Join(fp, "index.html")
- if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() {
- return false, ifp, ifi, err
+ if requestURL := ctx.Input.URL(); requestURL[len(requestURL)-1] == '/' {
+ ifp := filepath.Join(fp, "index.html")
+ if ifi, _ := os.Stat(ifp); ifi != nil && ifi.Mode().IsRegular() {
+ return false, ifp, ifi, err
+ }
}
return !BConfig.WebConfig.DirectoryIndex, fp, fi, err
}
diff --git a/swagger/swagger.go b/swagger/swagger.go
index e48dcf1e..409e264e 100644
--- a/swagger/swagger.go
+++ b/swagger/swagger.go
@@ -22,134 +22,149 @@ package swagger
// Swagger list the resource
type Swagger struct {
- SwaggerVersion string `json:"swagger,omitempty"`
- Infos Information `json:"info"`
- Host string `json:"host,omitempty"`
- BasePath string `json:"basePath,omitempty"`
- Schemes []string `json:"schemes,omitempty"`
- Consumes []string `json:"consumes,omitempty"`
- Produces []string `json:"produces,omitempty"`
- Paths map[string]Item `json:"paths"`
- Definitions map[string]Schema `json:"definitions,omitempty"`
- SecurityDefinitions map[string]Scurity `json:"securityDefinitions,omitempty"`
- Security map[string][]string `json:"security,omitempty"`
- Tags []Tag `json:"tags,omitempty"`
- ExternalDocs ExternalDocs `json:"externalDocs,omitempty"`
+ SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"`
+ Infos Information `json:"info" yaml:"info"`
+ Host string `json:"host,omitempty" yaml:"host,omitempty"`
+ BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"`
+ Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"`
+ Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"`
+ Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"`
+ Paths map[string]*Item `json:"paths" yaml:"paths"`
+ Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"`
+ SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"`
+ Security map[string][]string `json:"security,omitempty" yaml:"security,omitempty"`
+ Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"`
+ ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"`
}
// Information Provides metadata about the API. The metadata can be used by the clients if needed.
type Information struct {
- Title string `json:"title,omitempty"`
- Description string `json:"description,omitempty"`
- Version string `json:"version,omitempty"`
- TermsOfServiceURL string `json:"termsOfServiceUrl,omitempty"`
+ Title string `json:"title,omitempty" yaml:"title,omitempty"`
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ Version string `json:"version,omitempty" yaml:"version,omitempty"`
+ TermsOfService string `json:"termsOfService,omitempty" yaml:"termsOfService,omitempty"`
- Contact Contact `json:"contact,omitempty"`
- License License `json:"license,omitempty"`
+ Contact Contact `json:"contact,omitempty" yaml:"contact,omitempty"`
+ License *License `json:"license,omitempty" yaml:"license,omitempty"`
}
// Contact information for the exposed API.
type Contact struct {
- Name string `json:"name,omitempty"`
- URL string `json:"url,omitempty"`
- EMail string `json:"email,omitempty"`
+ Name string `json:"name,omitempty" yaml:"name,omitempty"`
+ URL string `json:"url,omitempty" yaml:"url,omitempty"`
+ EMail string `json:"email,omitempty" yaml:"email,omitempty"`
}
// License information for the exposed API.
type License struct {
- Name string `json:"name,omitempty"`
- URL string `json:"url,omitempty"`
+ Name string `json:"name,omitempty" yaml:"name,omitempty"`
+ URL string `json:"url,omitempty" yaml:"url,omitempty"`
}
// Item Describes the operations available on a single path.
type Item struct {
- Ref string `json:"$ref,omitempty"`
- Get *Operation `json:"get,omitempty"`
- Put *Operation `json:"put,omitempty"`
- Post *Operation `json:"post,omitempty"`
- Delete *Operation `json:"delete,omitempty"`
- Options *Operation `json:"options,omitempty"`
- Head *Operation `json:"head,omitempty"`
- Patch *Operation `json:"patch,omitempty"`
+ Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"`
+ Get *Operation `json:"get,omitempty" yaml:"get,omitempty"`
+ Put *Operation `json:"put,omitempty" yaml:"put,omitempty"`
+ Post *Operation `json:"post,omitempty" yaml:"post,omitempty"`
+ Delete *Operation `json:"delete,omitempty" yaml:"delete,omitempty"`
+ Options *Operation `json:"options,omitempty" yaml:"options,omitempty"`
+ Head *Operation `json:"head,omitempty" yaml:"head,omitempty"`
+ Patch *Operation `json:"patch,omitempty" yaml:"patch,omitempty"`
}
// Operation Describes a single API operation on a path.
type Operation struct {
- Tags []string `json:"tags,omitempty"`
- Summary string `json:"summary,omitempty"`
- Description string `json:"description,omitempty"`
- OperationID string `json:"operationId,omitempty"`
- Consumes []string `json:"consumes,omitempty"`
- Produces []string `json:"produces,omitempty"`
- Schemes []string `json:"schemes,omitempty"`
- Parameters []Parameter `json:"parameters,omitempty"`
- Responses map[string]Response `json:"responses,omitempty"`
- Deprecated bool `json:"deprecated,omitempty"`
+ Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"`
+ Summary string `json:"summary,omitempty" yaml:"summary,omitempty"`
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"`
+ Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"`
+ Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"`
+ Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"`
+ Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"`
+ Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"`
+ Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"`
}
// Parameter Describes a single operation parameter.
type Parameter struct {
- In string `json:"in,omitempty"`
- Name string `json:"name,omitempty"`
- Description string `json:"description,omitempty"`
- Required bool `json:"required,omitempty"`
- Schema Schema `json:"schema,omitempty"`
- Type string `json:"type,omitempty"`
- Format string `json:"format,omitempty"`
+ In string `json:"in,omitempty" yaml:"in,omitempty"`
+ Name string `json:"name,omitempty" yaml:"name,omitempty"`
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ Required bool `json:"required,omitempty" yaml:"required,omitempty"`
+ Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"`
+ Type string `json:"type,omitempty" yaml:"type,omitempty"`
+ Format string `json:"format,omitempty" yaml:"format,omitempty"`
+ Items *ParameterItems `json:"items,omitempty" yaml:"items,omitempty"`
+}
+
+// A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body".
+// http://swagger.io/specification/#itemsObject
+type ParameterItems struct {
+ Type string `json:"type,omitempty" yaml:"type,omitempty"`
+ Format string `json:"format,omitempty" yaml:"format,omitempty"`
+ Items []*ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` //Required if type is "array". Describes the type of items in the array.
+ CollectionFormat string `json:"collectionFormat,omitempty" yaml:"collectionFormat,omitempty"`
+ Default string `json:"default,omitempty" yaml:"default,omitempty"`
}
// Schema Object allows the definition of input and output data types.
type Schema struct {
- Ref string `json:"$ref,omitempty"`
- Title string `json:"title,omitempty"`
- Format string `json:"format,omitempty"`
- Description string `json:"description,omitempty"`
- Required []string `json:"required,omitempty"`
- Type string `json:"type,omitempty"`
- Properties map[string]Propertie `json:"properties,omitempty"`
+ Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"`
+ Title string `json:"title,omitempty" yaml:"title,omitempty"`
+ Format string `json:"format,omitempty" yaml:"format,omitempty"`
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ Required []string `json:"required,omitempty" yaml:"required,omitempty"`
+ Type string `json:"type,omitempty" yaml:"type,omitempty"`
+ Items *Schema `json:"items,omitempty" yaml:"items,omitempty"`
+ Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"`
}
// Propertie are taken from the JSON Schema definition but their definitions were adjusted to the Swagger Specification
type Propertie struct {
- Title string `json:"title,omitempty"`
- Description string `json:"description,omitempty"`
- Default string `json:"default,omitempty"`
- Type string `json:"type,omitempty"`
- Example string `json:"example,omitempty"`
- Required []string `json:"required,omitempty"`
- Format string `json:"format,omitempty"`
- ReadOnly bool `json:"readOnly,omitempty"`
- Properties map[string]Propertie `json:"properties,omitempty"`
+ Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"`
+ Title string `json:"title,omitempty" yaml:"title,omitempty"`
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ Default string `json:"default,omitempty" yaml:"default,omitempty"`
+ Type string `json:"type,omitempty" yaml:"type,omitempty"`
+ Example string `json:"example,omitempty" yaml:"example,omitempty"`
+ Required []string `json:"required,omitempty" yaml:"required,omitempty"`
+ Format string `json:"format,omitempty" yaml:"format,omitempty"`
+ ReadOnly bool `json:"readOnly,omitempty" yaml:"readOnly,omitempty"`
+ Properties map[string]Propertie `json:"properties,omitempty" yaml:"properties,omitempty"`
+ Items *Propertie `json:"items,omitempty" yaml:"items,omitempty"`
+ AdditionalProperties *Propertie `json:"additionalProperties,omitempty" yaml:"additionalProperties,omitempty"`
}
// Response as they are returned from executing this operation.
type Response struct {
- Description string `json:"description,omitempty"`
- Schema Schema `json:"schema,omitempty"`
- Ref string `json:"$ref,omitempty"`
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ Schema *Schema `json:"schema,omitempty" yaml:"schema,omitempty"`
+ Ref string `json:"$ref,omitempty" yaml:"$ref,omitempty"`
}
-// Scurity Allows the definition of a security scheme that can be used by the operations
-type Scurity struct {
- Type string `json:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2".
- Description string `json:"description,omitempty"`
- Name string `json:"name,omitempty"`
- In string `json:"in,omitempty"` // Valid values are "query" or "header".
- Flow string `json:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode".
- AuthorizationURL string `json:"authorizationUrl,omitempty"`
- TokenURL string `json:"tokenUrl,omitempty"`
- Scopes map[string]string `json:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme.
+// Security Allows the definition of a security scheme that can be used by the operations
+type Security struct {
+ Type string `json:"type,omitempty" yaml:"type,omitempty"` // Valid values are "basic", "apiKey" or "oauth2".
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ Name string `json:"name,omitempty" yaml:"name,omitempty"`
+ In string `json:"in,omitempty" yaml:"in,omitempty"` // Valid values are "query" or "header".
+ Flow string `json:"flow,omitempty" yaml:"flow,omitempty"` // Valid values are "implicit", "password", "application" or "accessCode".
+ AuthorizationURL string `json:"authorizationUrl,omitempty" yaml:"authorizationUrl,omitempty"`
+ TokenURL string `json:"tokenUrl,omitempty" yaml:"tokenUrl,omitempty"`
+ Scopes map[string]string `json:"scopes,omitempty" yaml:"scopes,omitempty"` // The available scopes for the OAuth2 security scheme.
}
// Tag Allows adding meta data to a single tag that is used by the Operation Object
type Tag struct {
- Name string `json:"name,omitempty"`
- Description string `json:"description,omitempty"`
- ExternalDocs ExternalDocs `json:"externalDocs,omitempty"`
+ Name string `json:"name,omitempty" yaml:"name,omitempty"`
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"`
}
// ExternalDocs include Additional external documentation
type ExternalDocs struct {
- Description string `json:"description,omitempty"`
- URL string `json:"url,omitempty"`
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ URL string `json:"url,omitempty" yaml:"url,omitempty"`
}
diff --git a/template.go b/template.go
index 494acc4f..5415f5f0 100644
--- a/template.go
+++ b/template.go
@@ -50,22 +50,16 @@ func ExecuteTemplate(wr io.Writer, name string, data interface{}) error {
defer templatesLock.RUnlock()
}
if t, ok := beeTemplates[name]; ok {
+ var err error
if t.Lookup(name) != nil {
- err := t.ExecuteTemplate(wr, name, data)
- if err != nil {
- logs.Trace("template Execute err:", err)
- }
- return err
+ err = t.ExecuteTemplate(wr, name, data)
} else {
- err := t.Execute(wr, data)
- if err != nil {
- if err != nil {
- logs.Trace("template Execute err:", err)
- }
- return err
- }
+ err = t.Execute(wr, data)
}
- return nil
+ if err != nil {
+ logs.Trace("template Execute err:", err)
+ }
+ return err
}
panic("can't find templatefile in the path:" + name)
}
diff --git a/templatefunc.go b/templatefunc.go
index 8558733f..01751717 100644
--- a/templatefunc.go
+++ b/templatefunc.go
@@ -280,15 +280,8 @@ func AssetsCSS(src string) template.HTML {
}
// ParseForm will parse form values to struct via tag.
-func ParseForm(form url.Values, obj interface{}) error {
- objT := reflect.TypeOf(obj)
- objV := reflect.ValueOf(obj)
- if !isStructPtr(objT) {
- return fmt.Errorf("%v must be a struct pointer", obj)
- }
- objT = objT.Elem()
- objV = objV.Elem()
-
+// Support for anonymous struct.
+func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error {
for i := 0; i < objT.NumField(); i++ {
fieldV := objV.Field(i)
if !fieldV.CanSet() {
@@ -296,6 +289,14 @@ func ParseForm(form url.Values, obj interface{}) error {
}
fieldT := objT.Field(i)
+ if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct {
+ err := parseFormToStruct(form, fieldT.Type, fieldV)
+ if err != nil {
+ return err
+ }
+ continue
+ }
+
tags := strings.Split(fieldT.Tag.Get("form"), ",")
var tag string
if len(tags) == 0 || len(tags[0]) == 0 {
@@ -384,6 +385,19 @@ func ParseForm(form url.Values, obj interface{}) error {
return nil
}
+// ParseForm will parse form values to struct via tag.
+func ParseForm(form url.Values, obj interface{}) error {
+ objT := reflect.TypeOf(obj)
+ objV := reflect.ValueOf(obj)
+ if !isStructPtr(objT) {
+ return fmt.Errorf("%v must be a struct pointer", obj)
+ }
+ objT = objT.Elem()
+ objV = objV.Elem()
+
+ return parseFormToStruct(form, objT, objV)
+}
+
var sliceOfInts = reflect.TypeOf([]int(nil))
var sliceOfStrings = reflect.TypeOf([]string(nil))
@@ -421,18 +435,18 @@ func RenderForm(obj interface{}) template.HTML {
fieldT := objT.Field(i)
- label, name, fType, id, class, ignored := parseFormTag(fieldT)
+ label, name, fType, id, class, ignored, required := parseFormTag(fieldT)
if ignored {
continue
}
- raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class))
+ raw = append(raw, renderFormField(label, name, fType, fieldV.Interface(), id, class, required))
}
return template.HTML(strings.Join(raw, ""))
}
// renderFormField returns a string containing HTML of a single form field.
-func renderFormField(label, name, fType string, value interface{}, id string, class string) string {
+func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string {
if id != "" {
id = " id=\"" + id + "\""
}
@@ -441,11 +455,16 @@ func renderFormField(label, name, fType string, value interface{}, id string, cl
class = " class=\"" + class + "\""
}
- if isValidForInput(fType) {
- return fmt.Sprintf(`%v`, label, id, class, name, fType, value)
+ requiredString := ""
+ if required {
+ requiredString = " required"
}
- return fmt.Sprintf(`%v<%v%v%v name="%v">%v%v>`, label, fType, id, class, name, value, fType)
+ if isValidForInput(fType) {
+ return fmt.Sprintf(`%v`, label, id, class, name, fType, value, requiredString)
+ }
+
+ return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v%v>`, label, fType, id, class, name, requiredString, value, fType)
}
// isValidForInput checks if fType is a valid value for the `type` property of an HTML input element.
@@ -461,7 +480,7 @@ func isValidForInput(fType string) bool {
// parseFormTag takes the stuct-tag of a StructField and parses the `form` value.
// returned are the form label, name-property, type and wether the field should be ignored.
-func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool) {
+func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) {
tags := strings.Split(fieldT.Tag.Get("form"), ",")
label = fieldT.Name + ": "
name = fieldT.Name
@@ -470,6 +489,12 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str
id = fieldT.Tag.Get("id")
class = fieldT.Tag.Get("class")
+ required = false
+ required_field := fieldT.Tag.Get("required")
+ if required_field != "-" && required_field != "" {
+ required, _ = strconv.ParseBool(required_field)
+ }
+
switch len(tags) {
case 1:
if tags[0] == "-" {
@@ -496,6 +521,7 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str
label = tags[2]
}
}
+
return
}
diff --git a/templatefunc_test.go b/templatefunc_test.go
index 98fbf7ab..a1ec1136 100644
--- a/templatefunc_test.go
+++ b/templatefunc_test.go
@@ -110,6 +110,17 @@ func TestHtmlunquote(t *testing.T) {
}
func TestParseForm(t *testing.T) {
+ type ExtendInfo struct {
+ Hobby string `form:"hobby"`
+ Memo string
+ }
+
+ type OtherInfo struct {
+ Organization string `form:"organization"`
+ Title string `form:"title"`
+ ExtendInfo
+ }
+
type user struct {
ID int `form:"-"`
tag string `form:"tag"`
@@ -119,19 +130,24 @@ func TestParseForm(t *testing.T) {
Intro string `form:",textarea"`
StrBool bool `form:"strbool"`
Date time.Time `form:"date,2006-01-02"`
+ OtherInfo
}
u := user{}
form := url.Values{
- "ID": []string{"1"},
- "-": []string{"1"},
- "tag": []string{"no"},
- "username": []string{"test"},
- "age": []string{"40"},
- "Email": []string{"test@gmail.com"},
- "Intro": []string{"I am an engineer!"},
- "strbool": []string{"yes"},
- "date": []string{"2014-11-12"},
+ "ID": []string{"1"},
+ "-": []string{"1"},
+ "tag": []string{"no"},
+ "username": []string{"test"},
+ "age": []string{"40"},
+ "Email": []string{"test@gmail.com"},
+ "Intro": []string{"I am an engineer!"},
+ "strbool": []string{"yes"},
+ "date": []string{"2014-11-12"},
+ "organization": []string{"beego"},
+ "title": []string{"CXO"},
+ "hobby": []string{"Basketball"},
+ "memo": []string{"nothing"},
}
if err := ParseForm(form, u); err == nil {
t.Fatal("nothing will be changed")
@@ -164,6 +180,18 @@ func TestParseForm(t *testing.T) {
if y != 2014 || m.String() != "November" || d != 12 {
t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String())
}
+ if u.Organization != "beego" {
+ t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization)
+ }
+ if u.Title != "CXO" {
+ t.Errorf("Title should equal `CXO`, but got `%v`", u.Title)
+ }
+ if u.Hobby != "Basketball" {
+ t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby)
+ }
+ if len(u.Memo) != 0 {
+ t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo))
+ }
}
func TestRenderForm(t *testing.T) {
@@ -195,54 +223,78 @@ func TestRenderForm(t *testing.T) {
}
func TestRenderFormField(t *testing.T) {
- html := renderFormField("Label: ", "Name", "text", "Value", "", "")
+ html := renderFormField("Label: ", "Name", "text", "Value", "", "", false)
if html != `Label: ` {
t.Errorf("Wrong html output for input[type=text]: %v ", html)
}
- html = renderFormField("Label: ", "Name", "textarea", "Value", "", "")
+ html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", false)
if html != `Label: ` {
t.Errorf("Wrong html output for textarea: %v ", html)
}
+
+ html = renderFormField("Label: ", "Name", "textarea", "Value", "", "", true)
+ if html != `Label: ` {
+ t.Errorf("Wrong html output for textarea: %v ", html)
+ }
}
func TestParseFormTag(t *testing.T) {
// create struct to contain field with different types of struct-tag `form`
type user struct {
- All int `form:"name,text,年龄:"`
- NoName int `form:",hidden,年龄:"`
- OnlyLabel int `form:",,年龄:"`
- OnlyName int `form:"name" id:"name" class:"form-name"`
- Ignored int `form:"-"`
+ All int `form:"name,text,年龄:"`
+ NoName int `form:",hidden,年龄:"`
+ OnlyLabel int `form:",,年龄:"`
+ OnlyName int `form:"name" id:"name" class:"form-name"`
+ Ignored int `form:"-"`
+ Required int `form:"name" required:"true"`
+ IgnoreRequired int `form:"name"`
+ NotRequired int `form:"name" required:"false"`
}
objT := reflect.TypeOf(&user{}).Elem()
- label, name, fType, id, class, ignored := parseFormTag(objT.Field(0))
+ label, name, fType, id, class, ignored, required := parseFormTag(objT.Field(0))
if !(name == "name" && label == "年龄:" && fType == "text" && ignored == false) {
t.Errorf("Form Tag with name, label and type was not correctly parsed.")
}
- label, name, fType, id, class, ignored = parseFormTag(objT.Field(1))
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(1))
if !(name == "NoName" && label == "年龄:" && fType == "hidden" && ignored == false) {
t.Errorf("Form Tag with label and type but without name was not correctly parsed.")
}
- label, name, fType, id, class, ignored = parseFormTag(objT.Field(2))
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(2))
if !(name == "OnlyLabel" && label == "年龄:" && fType == "text" && ignored == false) {
t.Errorf("Form Tag containing only label was not correctly parsed.")
}
- label, name, fType, id, class, ignored = parseFormTag(objT.Field(3))
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(3))
if !(name == "name" && label == "OnlyName: " && fType == "text" && ignored == false &&
id == "name" && class == "form-name") {
t.Errorf("Form Tag containing only name was not correctly parsed.")
}
- label, name, fType, id, class, ignored = parseFormTag(objT.Field(4))
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(4))
if ignored == false {
t.Errorf("Form Tag that should be ignored was not correctly parsed.")
}
+
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(5))
+ if !(name == "name" && required == true) {
+ t.Errorf("Form Tag containing only name and required was not correctly parsed.")
+ }
+
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(6))
+ if !(name == "name" && required == false) {
+ t.Errorf("Form Tag containing only name and ignore required was not correctly parsed.")
+ }
+
+ label, name, fType, id, class, ignored, required = parseFormTag(objT.Field(7))
+ if !(name == "name" && required == false) {
+ t.Errorf("Form Tag containing only name and not required was not correctly parsed.")
+ }
+
}
func TestMapGet(t *testing.T) {
diff --git a/toolbox/statistics.go b/toolbox/statistics.go
index 32eb7e23..69b88772 100644
--- a/toolbox/statistics.go
+++ b/toolbox/statistics.go
@@ -99,9 +99,13 @@ func (m *URLMap) GetMap() map[string]interface{} {
fmt.Sprintf("% -50s", k),
fmt.Sprintf("% -10s", kk),
fmt.Sprintf("% -16d", vv.RequestNum),
+ fmt.Sprintf("%d", vv.TotalTime),
fmt.Sprintf("% -16s", toS(vv.TotalTime)),
+ fmt.Sprintf("%d", vv.MaxTime),
fmt.Sprintf("% -16s", toS(vv.MaxTime)),
+ fmt.Sprintf("%d", vv.MinTime),
fmt.Sprintf("% -16s", toS(vv.MinTime)),
+ fmt.Sprintf("%d", time.Duration(int64(vv.TotalTime)/vv.RequestNum)),
fmt.Sprintf("% -16s", toS(time.Duration(int64(vv.TotalTime)/vv.RequestNum))),
}
resultLists = append(resultLists, result)
diff --git a/utils/captcha/image.go b/utils/captcha/image.go
index 1057192a..0ceb8e42 100644
--- a/utils/captcha/image.go
+++ b/utils/captcha/image.go
@@ -359,6 +359,9 @@ func (m *Image) calculateSizes(width, height, ncount int) {
}
// Calculate dot size.
m.dotSize = int(nh / fh)
+ if m.dotSize < 1 {
+ m.dotSize = 1
+ }
// Save everything, making the actual width smaller by 1 dot to account
// for spacing between digits.
m.numWidth = int(nw) - m.dotSize
diff --git a/utils/file_test.go b/utils/file_test.go
index 020d7e4c..86d1a700 100644
--- a/utils/file_test.go
+++ b/utils/file_test.go
@@ -41,7 +41,7 @@ func TestFileExists(t *testing.T) {
}
if FileExists(noExistedFile) {
- t.Errorf("Wierd, how could this file exists: %s", noExistedFile)
+ t.Errorf("Weird, how could this file exists: %s", noExistedFile)
}
}
@@ -52,7 +52,7 @@ func TestSearchFile(t *testing.T) {
}
t.Log(path)
- path, err = SearchFile(noExistedFile, ".")
+ _, err = SearchFile(noExistedFile, ".")
if err == nil {
t.Errorf("err shouldnot be nil, got path: %s", SelfDir())
}