", result)
+ }
+ }
+
+ func() {
+ ctrl.TplName = "file2.tpl"
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("TestAdditionalViewPaths expected error")
+ }
+ }()
+ ctrl.RenderString()
+ }()
+
+ ctrl.TplName = "file2.tpl"
+ ctrl.ViewPath = dir2
+ ctrl.RenderString()
+}
diff --git a/error.go b/error.go
index ab626247..b913db39 100644
--- a/error.go
+++ b/error.go
@@ -252,6 +252,30 @@ func forbidden(rw http.ResponseWriter, r *http.Request) {
)
}
+// show 422 missing xsrf token
+func missingxsrf(rw http.ResponseWriter, r *http.Request) {
+	responseError(rw, r,
+		422,
+		"<br>The page you have requested is forbidden."+
+			"<br>Perhaps you are here because:"+
+			"<br><br><ul>"+
+			"<br>'_xsrf' argument missing from POST"+
+			"</ul>",
+	)
+}
+
+// show 417 invalid xsrf token
+func invalidxsrf(rw http.ResponseWriter, r *http.Request) {
+	responseError(rw, r,
+		417,
+		"<br>The page you have requested is forbidden."+
+			"<br>Perhaps you are here because:"+
+			"<br><br><ul>"+
+			"<br>expected XSRF not found"+
+			"</ul>",
+	)
+}
+
// show 404 not found error.
func notFound(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
diff --git a/error_test.go b/error_test.go
index 2fb8f962..378aa953 100644
--- a/error_test.go
+++ b/error_test.go
@@ -52,7 +52,7 @@ func TestErrorCode_01(t *testing.T) {
if w.Code != code {
t.Fail()
}
- if !strings.Contains(string(w.Body.Bytes()), http.StatusText(code)) {
+ if !strings.Contains(w.Body.String(), http.StatusText(code)) {
t.Fail()
}
}
@@ -82,7 +82,7 @@ func TestErrorCode_03(t *testing.T) {
if w.Code != 200 {
t.Fail()
}
- if string(w.Body.Bytes()) != parseCodeError {
+ if w.Body.String() != parseCodeError {
t.Fail()
}
}
diff --git a/flash_test.go b/flash_test.go
index 640d54de..d5e9608d 100644
--- a/flash_test.go
+++ b/flash_test.go
@@ -48,7 +48,7 @@ func TestFlashHeader(t *testing.T) {
// match for the expected header
res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
// validate the assertion
- if res != true {
+ if !res {
t.Errorf("TestFlashHeader() unable to validate flash message")
}
}
diff --git a/grace/conn.go b/grace/conn.go
index 6807e1ac..e020f850 100644
--- a/grace/conn.go
+++ b/grace/conn.go
@@ -3,14 +3,17 @@ package grace
import (
"errors"
"net"
+ "sync"
)
type graceConn struct {
net.Conn
server *Server
+ m sync.Mutex
+ closed bool
}
-func (c graceConn) Close() (err error) {
+func (c *graceConn) Close() (err error) {
defer func() {
if r := recover(); r != nil {
switch x := r.(type) {
@@ -23,6 +26,14 @@ func (c graceConn) Close() (err error) {
}
}
}()
+
+ c.m.Lock()
+ if c.closed {
+ c.m.Unlock()
+ return
+ }
c.server.wg.Done()
+ c.closed = true
+ c.m.Unlock()
return c.Conn.Close()
}
diff --git a/grace/grace.go b/grace/grace.go
index af4e9068..6ebf8455 100644
--- a/grace/grace.go
+++ b/grace/grace.go
@@ -85,23 +85,31 @@ var (
isChild bool
socketOrder string
- once sync.Once
+
+ hookableSignals []os.Signal
)
-func onceInit() {
- regLock = &sync.Mutex{}
+func init() {
flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
+
+ regLock = &sync.Mutex{}
runningServers = make(map[string]*Server)
runningServersOrder = []string{}
socketPtrOffsetMap = make(map[string]uint)
+
+ hookableSignals = []os.Signal{
+ syscall.SIGHUP,
+ syscall.SIGINT,
+ syscall.SIGTERM,
+ }
}
// NewServer returns a new graceServer.
func NewServer(addr string, handler http.Handler) (srv *Server) {
- once.Do(onceInit)
regLock.Lock()
defer regLock.Unlock()
+
if !flag.Parsed() {
flag.Parse()
}
diff --git a/grace/listener.go b/grace/listener.go
index 5439d0b2..7ede63a3 100644
--- a/grace/listener.go
+++ b/grace/listener.go
@@ -21,7 +21,7 @@ func newGraceListener(l net.Listener, srv *Server) (el *graceListener) {
server: srv,
}
go func() {
- _ = <-el.stop
+ <-el.stop
el.stopped = true
el.stop <- el.Listener.Close()
}()
@@ -37,7 +37,7 @@ func (gl *graceListener) Accept() (c net.Conn, err error) {
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
- c = graceConn{
+ c = &graceConn{
Conn: tc,
server: gl.server,
}
diff --git a/grace/server.go b/grace/server.go
index 101bda56..b8242335 100644
--- a/grace/server.go
+++ b/grace/server.go
@@ -162,9 +162,7 @@ func (srv *Server) handleSignals() {
signal.Notify(
srv.sigChan,
- syscall.SIGHUP,
- syscall.SIGINT,
- syscall.SIGTERM,
+ hookableSignals...,
)
pid := syscall.Getpid()
@@ -198,7 +196,6 @@ func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
for _, f := range srv.SignalHooks[ppFlag][sig] {
f()
}
- return
}
// shutdown closes the listener so that no new connections are accepted. it also
@@ -290,3 +287,19 @@ func (srv *Server) fork() (err error) {
return
}
+
+// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal.
+func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
+ if ppFlag != PreSignal && ppFlag != PostSignal {
+ err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal")
+ return
+ }
+ for _, s := range hookableSignals {
+ if s == sig {
+ srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
+ return
+ }
+ }
+ err = fmt.Errorf("Signal '%v' is not supported", sig)
+ return
+}
diff --git a/hooks.go b/hooks.go
index b5a5e6c5..c5ec8e2d 100644
--- a/hooks.go
+++ b/hooks.go
@@ -32,6 +32,8 @@ func registerDefaultErrorHandler() error {
"502": badGateway,
"503": serviceUnavailable,
"504": gatewayTimeout,
+ "417": invalidxsrf,
+ "422": missingxsrf,
}
for e, h := range m {
if _, ok := ErrorMaps[e]; !ok {
@@ -55,9 +57,9 @@ func registerSession() error {
conf.ProviderConfig = filepath.ToSlash(BConfig.WebConfig.Session.SessionProviderConfig)
conf.DisableHTTPOnly = BConfig.WebConfig.Session.SessionDisableHTTPOnly
conf.Domain = BConfig.WebConfig.Session.SessionDomain
- conf.EnableSidInHttpHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader
- conf.SessionNameInHttpHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader
- conf.EnableSidInUrlQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery
+ conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader
+ conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader
+ conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery
} else {
if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil {
return err
@@ -72,7 +74,8 @@ func registerSession() error {
}
func registerTemplate() error {
- if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil {
+ defer lockViewPaths()
+ if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil {
if BConfig.RunMode == DEV {
logs.Warn(err)
}
diff --git a/httplib/README.md b/httplib/README.md
index 6a72cf7c..97df8e6b 100644
--- a/httplib/README.md
+++ b/httplib/README.md
@@ -32,7 +32,7 @@ The default timeout is `60` seconds, function prototype:
SetTimeout(connectTimeout, readWriteTimeout time.Duration)
-Exmaple:
+Example:
// GET
httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second)
diff --git a/httplib/httplib.go b/httplib/httplib.go
index 510ad75e..4fd572d6 100644
--- a/httplib/httplib.go
+++ b/httplib/httplib.go
@@ -140,6 +140,7 @@ type BeegoHTTPSettings struct {
EnableCookie bool
Gzip bool
DumpBody bool
+ Retries int // if set to -1 means will retry forever
}
// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request.
@@ -189,6 +190,15 @@ func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
return b
}
+// Retries sets Retries times.
+// default is 0 means no retried.
+// -1 means retried forever.
+// others means retried times.
+func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest {
+ b.setting.Retries = times
+ return b
+}
+
// DumpBody setting whether need to Dump the Body.
func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
b.setting.DumpBody = isdump
@@ -325,7 +335,7 @@ func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error)
func (b *BeegoHTTPRequest) buildURL(paramBody string) {
// build GET url with query string
if b.req.Method == "GET" && len(paramBody) > 0 {
- if strings.Index(b.url, "?") != -1 {
+ if strings.Contains(b.url, "?") {
b.url += "&" + paramBody
} else {
b.url = b.url + "?" + paramBody
@@ -334,7 +344,7 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) {
}
// build POST/PUT/PATCH url and body
- if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH") && b.req.Body == nil {
+ if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil {
// with files
if len(b.files) > 0 {
pr, pw := io.Pipe()
@@ -390,7 +400,7 @@ func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
}
// DoRequest will do the client.Do
-func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
+func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) {
var paramBody string
if len(b.params) > 0 {
var buf bytes.Buffer
@@ -467,7 +477,16 @@ func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
}
b.dump = dump
}
- return client.Do(b.req)
+ // retries default value is 0, it will run once.
+ // retries equal to -1, it will run forever until success
+ // retries is setted, it will retries fixed times.
+ for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
+ resp, err = client.Do(b.req)
+ if err == nil {
+ break
+ }
+ }
+ return resp, err
}
// String returns the body string in response.
@@ -501,9 +520,9 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) {
return nil, err
}
b.body, err = ioutil.ReadAll(reader)
- } else {
- b.body, err = ioutil.ReadAll(resp.Body)
+ return b.body, err
}
+ b.body, err = ioutil.ReadAll(resp.Body)
return b.body, err
}
diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go
index 05815054..32d3e7f6 100644
--- a/httplib/httplib_test.go
+++ b/httplib/httplib_test.go
@@ -102,6 +102,14 @@ func TestSimpleDelete(t *testing.T) {
t.Log(str)
}
+func TestSimpleDeleteParam(t *testing.T) {
+ str, err := Delete("http://httpbin.org/delete").Param("key", "val").String()
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(str)
+}
+
func TestWithCookie(t *testing.T) {
v := "smallfish"
str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String()
diff --git a/logs/alils/alils.go b/logs/alils/alils.go
new file mode 100644
index 00000000..867ff4cb
--- /dev/null
+++ b/logs/alils/alils.go
@@ -0,0 +1,186 @@
+package alils
+
+import (
+ "encoding/json"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/astaxie/beego/logs"
+ "github.com/gogo/protobuf/proto"
+)
+
+const (
+ // CacheSize set the flush size
+ CacheSize int = 64
+ // Delimiter define the topic delimiter
+ Delimiter string = "##"
+)
+
+// Config is the Config for Ali Log
+type Config struct {
+ Project string `json:"project"`
+ Endpoint string `json:"endpoint"`
+ KeyID string `json:"key_id"`
+ KeySecret string `json:"key_secret"`
+ LogStore string `json:"log_store"`
+ Topics []string `json:"topics"`
+ Source string `json:"source"`
+ Level int `json:"level"`
+ FlushWhen int `json:"flush_when"`
+}
+
+// aliLSWriter implements LoggerInterface.
+// it writes messages in keep-live tcp connection.
+type aliLSWriter struct {
+ store *LogStore
+ group []*LogGroup
+ withMap bool
+ groupMap map[string]*LogGroup
+ lock *sync.Mutex
+ Config
+}
+
+// NewAliLS create a new Logger
+func NewAliLS() logs.Logger {
+ alils := new(aliLSWriter)
+ alils.Level = logs.LevelTrace
+ return alils
+}
+
+// Init parse config and init struct
+func (c *aliLSWriter) Init(jsonConfig string) (err error) {
+
+ json.Unmarshal([]byte(jsonConfig), c)
+
+ if c.FlushWhen > CacheSize {
+ c.FlushWhen = CacheSize
+ }
+
+ prj := &LogProject{
+ Name: c.Project,
+ Endpoint: c.Endpoint,
+ AccessKeyID: c.KeyID,
+ AccessKeySecret: c.KeySecret,
+ }
+
+ c.store, err = prj.GetLogStore(c.LogStore)
+ if err != nil {
+ return err
+ }
+
+ // Create default Log Group
+ c.group = append(c.group, &LogGroup{
+ Topic: proto.String(""),
+ Source: proto.String(c.Source),
+ Logs: make([]*Log, 0, c.FlushWhen),
+ })
+
+ // Create other Log Group
+ c.groupMap = make(map[string]*LogGroup)
+ for _, topic := range c.Topics {
+
+ lg := &LogGroup{
+ Topic: proto.String(topic),
+ Source: proto.String(c.Source),
+ Logs: make([]*Log, 0, c.FlushWhen),
+ }
+
+ c.group = append(c.group, lg)
+ c.groupMap[topic] = lg
+ }
+
+ if len(c.group) == 1 {
+ c.withMap = false
+ } else {
+ c.withMap = true
+ }
+
+ c.lock = &sync.Mutex{}
+
+ return nil
+}
+
+// WriteMsg write message in connection.
+// if connection is down, try to re-connect.
+func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) {
+
+ if level > c.Level {
+ return nil
+ }
+
+ var topic string
+ var content string
+ var lg *LogGroup
+ if c.withMap {
+
+ // Topic,LogGroup
+ strs := strings.SplitN(msg, Delimiter, 2)
+ if len(strs) == 2 {
+ pos := strings.LastIndex(strs[0], " ")
+ topic = strs[0][pos+1 : len(strs[0])]
+ content = strs[0][0:pos] + strs[1]
+ lg = c.groupMap[topic]
+ }
+
+ // send to empty Topic
+ if lg == nil {
+ content = msg
+ lg = c.group[0]
+ }
+ } else {
+ content = msg
+ lg = c.group[0]
+ }
+
+ c1 := &LogContent{
+ Key: proto.String("msg"),
+ Value: proto.String(content),
+ }
+
+ l := &Log{
+ Time: proto.Uint32(uint32(when.Unix())),
+ Contents: []*LogContent{
+ c1,
+ },
+ }
+
+ c.lock.Lock()
+ lg.Logs = append(lg.Logs, l)
+ c.lock.Unlock()
+
+ if len(lg.Logs) >= c.FlushWhen {
+ c.flush(lg)
+ }
+
+ return nil
+}
+
+// Flush implementing method. empty.
+func (c *aliLSWriter) Flush() {
+
+ // flush all group
+ for _, lg := range c.group {
+ c.flush(lg)
+ }
+}
+
+// Destroy destroy connection writer and close tcp listener.
+func (c *aliLSWriter) Destroy() {
+}
+
+func (c *aliLSWriter) flush(lg *LogGroup) {
+
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ err := c.store.PutLogs(lg)
+ if err != nil {
+ return
+ }
+
+ lg.Logs = make([]*Log, 0, c.FlushWhen)
+}
+
+func init() {
+ logs.Register(logs.AdapterAliLS, NewAliLS)
+}
diff --git a/logs/alils/config.go b/logs/alils/config.go
new file mode 100755
index 00000000..e8c24448
--- /dev/null
+++ b/logs/alils/config.go
@@ -0,0 +1,13 @@
+package alils
+
+const (
+ version = "0.5.0" // SDK version
+ signatureMethod = "hmac-sha1" // Signature method
+
+ // OffsetNewest stands for the log head offset, i.e. the offset that will be
+ // assigned to the next message that will be produced to the shard.
+ OffsetNewest = "end"
+ // OffsetOldest stands for the oldest offset available on the logstore for a
+ // shard.
+ OffsetOldest = "begin"
+)
diff --git a/logs/alils/log.pb.go b/logs/alils/log.pb.go
new file mode 100755
index 00000000..601b0d78
--- /dev/null
+++ b/logs/alils/log.pb.go
@@ -0,0 +1,1038 @@
+package alils
+
+import (
+ "fmt"
+ "io"
+ "math"
+
+ "github.com/gogo/protobuf/proto"
+ github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+var (
+ // ErrInvalidLengthLog invalid proto
+ ErrInvalidLengthLog = fmt.Errorf("proto: negative length found during unmarshaling")
+ // ErrIntOverflowLog overflow
+ ErrIntOverflowLog = fmt.Errorf("proto: integer overflow")
+)
+
+// Log define the proto Log
+type Log struct {
+ Time *uint32 `protobuf:"varint,1,req,name=Time" json:"Time,omitempty"`
+ Contents []*LogContent `protobuf:"bytes,2,rep,name=Contents" json:"Contents,omitempty"`
+ XXXUnrecognized []byte `json:"-"`
+}
+
+// Reset the Log
+func (m *Log) Reset() { *m = Log{} }
+
+// String return the Compact Log
+func (m *Log) String() string { return proto.CompactTextString(m) }
+
+// ProtoMessage not implemented
+func (*Log) ProtoMessage() {}
+
+// GetTime return the Log's Time
+func (m *Log) GetTime() uint32 {
+ if m != nil && m.Time != nil {
+ return *m.Time
+ }
+ return 0
+}
+
+// GetContents return the Log's Contents
+func (m *Log) GetContents() []*LogContent {
+ if m != nil {
+ return m.Contents
+ }
+ return nil
+}
+
+// LogContent define the Log content struct
+type LogContent struct {
+ Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"`
+ Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"`
+ XXXUnrecognized []byte `json:"-"`
+}
+
+// Reset LogContent
+func (m *LogContent) Reset() { *m = LogContent{} }
+
+// String return the compact text
+func (m *LogContent) String() string { return proto.CompactTextString(m) }
+
+// ProtoMessage not implemented
+func (*LogContent) ProtoMessage() {}
+
+// GetKey return the Key
+func (m *LogContent) GetKey() string {
+ if m != nil && m.Key != nil {
+ return *m.Key
+ }
+ return ""
+}
+
+// GetValue return the Value
+func (m *LogContent) GetValue() string {
+ if m != nil && m.Value != nil {
+ return *m.Value
+ }
+ return ""
+}
+
+// LogGroup define the logs struct
+type LogGroup struct {
+ Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"`
+ Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"`
+ Topic *string `protobuf:"bytes,3,opt,name=Topic" json:"Topic,omitempty"`
+ Source *string `protobuf:"bytes,4,opt,name=Source" json:"Source,omitempty"`
+ XXXUnrecognized []byte `json:"-"`
+}
+
+// Reset LogGroup
+func (m *LogGroup) Reset() { *m = LogGroup{} }
+
+// String return the compact text
+func (m *LogGroup) String() string { return proto.CompactTextString(m) }
+
+// ProtoMessage not implemented
+func (*LogGroup) ProtoMessage() {}
+
+// GetLogs return the loggroup logs
+func (m *LogGroup) GetLogs() []*Log {
+ if m != nil {
+ return m.Logs
+ }
+ return nil
+}
+
+// GetReserved return Reserved
+func (m *LogGroup) GetReserved() string {
+ if m != nil && m.Reserved != nil {
+ return *m.Reserved
+ }
+ return ""
+}
+
+// GetTopic return Topic
+func (m *LogGroup) GetTopic() string {
+ if m != nil && m.Topic != nil {
+ return *m.Topic
+ }
+ return ""
+}
+
+// GetSource return Source
+func (m *LogGroup) GetSource() string {
+ if m != nil && m.Source != nil {
+ return *m.Source
+ }
+ return ""
+}
+
+// LogGroupList define the LogGroups
+type LogGroupList struct {
+ LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"`
+ XXXUnrecognized []byte `json:"-"`
+}
+
+// Reset LogGroupList
+func (m *LogGroupList) Reset() { *m = LogGroupList{} }
+
+// String return compact text
+func (m *LogGroupList) String() string { return proto.CompactTextString(m) }
+
+// ProtoMessage not implemented
+func (*LogGroupList) ProtoMessage() {}
+
+// GetLogGroups return the LogGroups
+func (m *LogGroupList) GetLogGroups() []*LogGroup {
+ if m != nil {
+ return m.LogGroups
+ }
+ return nil
+}
+
+// Marshal the logs to byte slice
+func (m *Log) Marshal() (data []byte, err error) {
+ size := m.Size()
+ data = make([]byte, size)
+ n, err := m.MarshalTo(data)
+ if err != nil {
+ return nil, err
+ }
+ return data[:n], nil
+}
+
+// MarshalTo data
+func (m *Log) MarshalTo(data []byte) (int, error) {
+ var i int
+ _ = i
+ var l int
+ _ = l
+ if m.Time == nil {
+ return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time")
+ }
+ data[i] = 0x8
+ i++
+ i = encodeVarintLog(data, i, uint64(*m.Time))
+ if len(m.Contents) > 0 {
+ for _, msg := range m.Contents {
+ data[i] = 0x12
+ i++
+ i = encodeVarintLog(data, i, uint64(msg.Size()))
+ n, err := msg.MarshalTo(data[i:])
+ if err != nil {
+ return 0, err
+ }
+ i += n
+ }
+ }
+ if m.XXXUnrecognized != nil {
+ i += copy(data[i:], m.XXXUnrecognized)
+ }
+ return i, nil
+}
+
+// Marshal LogContent
+func (m *LogContent) Marshal() (data []byte, err error) {
+ size := m.Size()
+ data = make([]byte, size)
+ n, err := m.MarshalTo(data)
+ if err != nil {
+ return nil, err
+ }
+ return data[:n], nil
+}
+
+// MarshalTo logcontent to data
+func (m *LogContent) MarshalTo(data []byte) (int, error) {
+ var i int
+ _ = i
+ var l int
+ _ = l
+ if m.Key == nil {
+ return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key")
+ }
+ data[i] = 0xa
+ i++
+ i = encodeVarintLog(data, i, uint64(len(*m.Key)))
+ i += copy(data[i:], *m.Key)
+
+ if m.Value == nil {
+ return 0, github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value")
+ }
+ data[i] = 0x12
+ i++
+ i = encodeVarintLog(data, i, uint64(len(*m.Value)))
+ i += copy(data[i:], *m.Value)
+ if m.XXXUnrecognized != nil {
+ i += copy(data[i:], m.XXXUnrecognized)
+ }
+ return i, nil
+}
+
+// Marshal LogGroup
+func (m *LogGroup) Marshal() (data []byte, err error) {
+ size := m.Size()
+ data = make([]byte, size)
+ n, err := m.MarshalTo(data)
+ if err != nil {
+ return nil, err
+ }
+ return data[:n], nil
+}
+
+// MarshalTo LogGroup to data
+func (m *LogGroup) MarshalTo(data []byte) (int, error) {
+ var i int
+ _ = i
+ var l int
+ _ = l
+ if len(m.Logs) > 0 {
+ for _, msg := range m.Logs {
+ data[i] = 0xa
+ i++
+ i = encodeVarintLog(data, i, uint64(msg.Size()))
+ n, err := msg.MarshalTo(data[i:])
+ if err != nil {
+ return 0, err
+ }
+ i += n
+ }
+ }
+ if m.Reserved != nil {
+ data[i] = 0x12
+ i++
+ i = encodeVarintLog(data, i, uint64(len(*m.Reserved)))
+ i += copy(data[i:], *m.Reserved)
+ }
+ if m.Topic != nil {
+ data[i] = 0x1a
+ i++
+ i = encodeVarintLog(data, i, uint64(len(*m.Topic)))
+ i += copy(data[i:], *m.Topic)
+ }
+ if m.Source != nil {
+ data[i] = 0x22
+ i++
+ i = encodeVarintLog(data, i, uint64(len(*m.Source)))
+ i += copy(data[i:], *m.Source)
+ }
+ if m.XXXUnrecognized != nil {
+ i += copy(data[i:], m.XXXUnrecognized)
+ }
+ return i, nil
+}
+
+// Marshal LogGroupList
+func (m *LogGroupList) Marshal() (data []byte, err error) {
+ size := m.Size()
+ data = make([]byte, size)
+ n, err := m.MarshalTo(data)
+ if err != nil {
+ return nil, err
+ }
+ return data[:n], nil
+}
+
+// MarshalTo LogGroupList to data
+func (m *LogGroupList) MarshalTo(data []byte) (int, error) {
+ var i int
+ _ = i
+ var l int
+ _ = l
+ if len(m.LogGroups) > 0 {
+ for _, msg := range m.LogGroups {
+ data[i] = 0xa
+ i++
+ i = encodeVarintLog(data, i, uint64(msg.Size()))
+ n, err := msg.MarshalTo(data[i:])
+ if err != nil {
+ return 0, err
+ }
+ i += n
+ }
+ }
+ if m.XXXUnrecognized != nil {
+ i += copy(data[i:], m.XXXUnrecognized)
+ }
+ return i, nil
+}
+
+func encodeFixed64Log(data []byte, offset int, v uint64) int {
+ data[offset] = uint8(v)
+ data[offset+1] = uint8(v >> 8)
+ data[offset+2] = uint8(v >> 16)
+ data[offset+3] = uint8(v >> 24)
+ data[offset+4] = uint8(v >> 32)
+ data[offset+5] = uint8(v >> 40)
+ data[offset+6] = uint8(v >> 48)
+ data[offset+7] = uint8(v >> 56)
+ return offset + 8
+}
+func encodeFixed32Log(data []byte, offset int, v uint32) int {
+ data[offset] = uint8(v)
+ data[offset+1] = uint8(v >> 8)
+ data[offset+2] = uint8(v >> 16)
+ data[offset+3] = uint8(v >> 24)
+ return offset + 4
+}
+func encodeVarintLog(data []byte, offset int, v uint64) int {
+ for v >= 1<<7 {
+ data[offset] = uint8(v&0x7f | 0x80)
+ v >>= 7
+ offset++
+ }
+ data[offset] = uint8(v)
+ return offset + 1
+}
+
+// Size return the log's size
+func (m *Log) Size() (n int) {
+ var l int
+ _ = l
+ if m.Time != nil {
+ n += 1 + sovLog(uint64(*m.Time))
+ }
+ if len(m.Contents) > 0 {
+ for _, e := range m.Contents {
+ l = e.Size()
+ n += 1 + l + sovLog(uint64(l))
+ }
+ }
+ if m.XXXUnrecognized != nil {
+ n += len(m.XXXUnrecognized)
+ }
+ return n
+}
+
+// Size return LogContent size based on Key and Value
+func (m *LogContent) Size() (n int) {
+ var l int
+ _ = l
+ if m.Key != nil {
+ l = len(*m.Key)
+ n += 1 + l + sovLog(uint64(l))
+ }
+ if m.Value != nil {
+ l = len(*m.Value)
+ n += 1 + l + sovLog(uint64(l))
+ }
+ if m.XXXUnrecognized != nil {
+ n += len(m.XXXUnrecognized)
+ }
+ return n
+}
+
+// Size return LogGroup size based on Logs
+func (m *LogGroup) Size() (n int) {
+ var l int
+ _ = l
+ if len(m.Logs) > 0 {
+ for _, e := range m.Logs {
+ l = e.Size()
+ n += 1 + l + sovLog(uint64(l))
+ }
+ }
+ if m.Reserved != nil {
+ l = len(*m.Reserved)
+ n += 1 + l + sovLog(uint64(l))
+ }
+ if m.Topic != nil {
+ l = len(*m.Topic)
+ n += 1 + l + sovLog(uint64(l))
+ }
+ if m.Source != nil {
+ l = len(*m.Source)
+ n += 1 + l + sovLog(uint64(l))
+ }
+ if m.XXXUnrecognized != nil {
+ n += len(m.XXXUnrecognized)
+ }
+ return n
+}
+
+// Size return LogGroupList size
+func (m *LogGroupList) Size() (n int) {
+ var l int
+ _ = l
+ if len(m.LogGroups) > 0 {
+ for _, e := range m.LogGroups {
+ l = e.Size()
+ n += 1 + l + sovLog(uint64(l))
+ }
+ }
+ if m.XXXUnrecognized != nil {
+ n += len(m.XXXUnrecognized)
+ }
+ return n
+}
+
+func sovLog(x uint64) (n int) {
+ for {
+ n++
+ x >>= 7
+ if x == 0 {
+ break
+ }
+ }
+ return n
+}
+func sozLog(x uint64) (n int) {
+ return sovLog((x << 1) ^ (x >> 63))
+}
+
+// Unmarshal data to log
+func (m *Log) Unmarshal(data []byte) error {
+ var hasFields [1]uint64
+ l := len(data)
+ iNdEx := 0
+ for iNdEx < l {
+ preIndex := iNdEx
+ var wire uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ wire |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ fieldNum := int32(wire >> 3)
+ wireType := int(wire & 0x7)
+ if wireType == 4 {
+ return fmt.Errorf("proto: Log: wiretype end group for non-group")
+ }
+ if fieldNum <= 0 {
+ return fmt.Errorf("proto: Log: illegal tag %d (wire type %d)", fieldNum, wire)
+ }
+ switch fieldNum {
+ case 1:
+ if wireType != 0 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType)
+ }
+ var v uint32
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ v |= (uint32(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ m.Time = &v
+ hasFields[0] |= uint64(0x00000001)
+ case 2:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Contents", wireType)
+ }
+ var msglen int
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ msglen |= (int(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ if msglen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + msglen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.Contents = append(m.Contents, &LogContent{})
+ if err := m.Contents[len(m.Contents)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
+ return err
+ }
+ iNdEx = postIndex
+ default:
+ iNdEx = preIndex
+ skippy, err := skipLog(data[iNdEx:])
+ if err != nil {
+ return err
+ }
+ if skippy < 0 {
+ return ErrInvalidLengthLog
+ }
+ if (iNdEx + skippy) > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...)
+ iNdEx += skippy
+ }
+ }
+ if hasFields[0]&uint64(0x00000001) == 0 {
+ return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Time")
+ }
+
+ if iNdEx > l {
+ return io.ErrUnexpectedEOF
+ }
+ return nil
+}
+
+// Unmarshal data to LogContent
+func (m *LogContent) Unmarshal(data []byte) error {
+ var hasFields [1]uint64
+ l := len(data)
+ iNdEx := 0
+ for iNdEx < l {
+ preIndex := iNdEx
+ var wire uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ wire |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ fieldNum := int32(wire >> 3)
+ wireType := int(wire & 0x7)
+ if wireType == 4 {
+ return fmt.Errorf("proto: Content: wiretype end group for non-group")
+ }
+ if fieldNum <= 0 {
+ return fmt.Errorf("proto: Content: illegal tag %d (wire type %d)", fieldNum, wire)
+ }
+ switch fieldNum {
+ case 1:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ s := string(data[iNdEx:postIndex])
+ m.Key = &s
+ iNdEx = postIndex
+ hasFields[0] |= uint64(0x00000001)
+ case 2:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Value", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ s := string(data[iNdEx:postIndex])
+ m.Value = &s
+ iNdEx = postIndex
+ hasFields[0] |= uint64(0x00000002)
+ default:
+ iNdEx = preIndex
+ skippy, err := skipLog(data[iNdEx:])
+ if err != nil {
+ return err
+ }
+ if skippy < 0 {
+ return ErrInvalidLengthLog
+ }
+ if (iNdEx + skippy) > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...)
+ iNdEx += skippy
+ }
+ }
+ if hasFields[0]&uint64(0x00000001) == 0 {
+ return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Key")
+ }
+ if hasFields[0]&uint64(0x00000002) == 0 {
+ return github_com_gogo_protobuf_proto.NewRequiredNotSetError("Value")
+ }
+
+ if iNdEx > l {
+ return io.ErrUnexpectedEOF
+ }
+ return nil
+}
+
+// Unmarshal data to LogGroup
+func (m *LogGroup) Unmarshal(data []byte) error {
+ l := len(data)
+ iNdEx := 0
+ for iNdEx < l {
+ preIndex := iNdEx
+ var wire uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ wire |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ fieldNum := int32(wire >> 3)
+ wireType := int(wire & 0x7)
+ if wireType == 4 {
+ return fmt.Errorf("proto: LogGroup: wiretype end group for non-group")
+ }
+ if fieldNum <= 0 {
+ return fmt.Errorf("proto: LogGroup: illegal tag %d (wire type %d)", fieldNum, wire)
+ }
+ switch fieldNum {
+ case 1:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Logs", wireType)
+ }
+ var msglen int
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ msglen |= (int(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ if msglen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + msglen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.Logs = append(m.Logs, &Log{})
+ if err := m.Logs[len(m.Logs)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
+ return err
+ }
+ iNdEx = postIndex
+ case 2:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Reserved", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ s := string(data[iNdEx:postIndex])
+ m.Reserved = &s
+ iNdEx = postIndex
+ case 3:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Topic", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ s := string(data[iNdEx:postIndex])
+ m.Topic = &s
+ iNdEx = postIndex
+ case 4:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field Source", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowLog
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := data[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthLog
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ s := string(data[iNdEx:postIndex])
+ m.Source = &s
+ iNdEx = postIndex
+ default:
+ iNdEx = preIndex
+ skippy, err := skipLog(data[iNdEx:])
+ if err != nil {
+ return err
+ }
+ if skippy < 0 {
+ return ErrInvalidLengthLog
+ }
+ if (iNdEx + skippy) > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...)
+ iNdEx += skippy
+ }
+ }
+
+ if iNdEx > l {
+ return io.ErrUnexpectedEOF
+ }
+ return nil
+}
+
+// Unmarshal data to LogGroupList
+//
+// Unmarshal decodes a protobuf-encoded LogGroupList from data.
+// Known field 1 (LogGroups, wire type 2 = length-delimited) is decoded
+// recursively via LogGroup.Unmarshal; every other field is skipped with
+// skipLog and its raw bytes are preserved in XXXUnrecognized.
+func (m *LogGroupList) Unmarshal(data []byte) error {
+	l := len(data)
+	iNdEx := 0
+	for iNdEx < l {
+		preIndex := iNdEx
+		var wire uint64
+		// Decode the tag varint: (field number << 3) | wire type.
+		for shift := uint(0); ; shift += 7 {
+			if shift >= 64 {
+				return ErrIntOverflowLog
+			}
+			if iNdEx >= l {
+				return io.ErrUnexpectedEOF
+			}
+			b := data[iNdEx]
+			iNdEx++
+			wire |= (uint64(b) & 0x7F) << shift
+			if b < 0x80 {
+				break
+			}
+		}
+		fieldNum := int32(wire >> 3)
+		wireType := int(wire & 0x7)
+		// Wire type 4 (end-group) is never valid at message top level.
+		if wireType == 4 {
+			return fmt.Errorf("proto: LogGroupList: wiretype end group for non-group")
+		}
+		if fieldNum <= 0 {
+			return fmt.Errorf("proto: LogGroupList: illegal tag %d (wire type %d)", fieldNum, wire)
+		}
+		switch fieldNum {
+		case 1:
+			if wireType != 2 {
+				return fmt.Errorf("proto: wrong wireType = %d for field LogGroups", wireType)
+			}
+			// Length-delimited: read the byte length, then decode that
+			// slice as a nested LogGroup message.
+			var msglen int
+			for shift := uint(0); ; shift += 7 {
+				if shift >= 64 {
+					return ErrIntOverflowLog
+				}
+				if iNdEx >= l {
+					return io.ErrUnexpectedEOF
+				}
+				b := data[iNdEx]
+				iNdEx++
+				msglen |= (int(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+			if msglen < 0 {
+				return ErrInvalidLengthLog
+			}
+			postIndex := iNdEx + msglen
+			if postIndex > l {
+				return io.ErrUnexpectedEOF
+			}
+			m.LogGroups = append(m.LogGroups, &LogGroup{})
+			if err := m.LogGroups[len(m.LogGroups)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
+				return err
+			}
+			iNdEx = postIndex
+		default:
+			// Unknown field: rewind to the tag and skip the whole field,
+			// keeping its raw bytes so re-marshaling is lossless.
+			iNdEx = preIndex
+			skippy, err := skipLog(data[iNdEx:])
+			if err != nil {
+				return err
+			}
+			if skippy < 0 {
+				return ErrInvalidLengthLog
+			}
+			if (iNdEx + skippy) > l {
+				return io.ErrUnexpectedEOF
+			}
+			m.XXXUnrecognized = append(m.XXXUnrecognized, data[iNdEx:iNdEx+skippy]...)
+			iNdEx += skippy
+		}
+	}
+
+	if iNdEx > l {
+		return io.ErrUnexpectedEOF
+	}
+	return nil
+}
+
+// skipLog returns the number of bytes occupied by the next complete
+// wire-format field at the start of data (tag included), or an error
+// if the field is malformed. Unmarshal uses it to step over unknown
+// fields so their raw bytes can be preserved.
+func skipLog(data []byte) (n int, err error) {
+	l := len(data)
+	iNdEx := 0
+	for iNdEx < l {
+		var wire uint64
+		// Decode the tag varint to learn the wire type.
+		for shift := uint(0); ; shift += 7 {
+			if shift >= 64 {
+				return 0, ErrIntOverflowLog
+			}
+			if iNdEx >= l {
+				return 0, io.ErrUnexpectedEOF
+			}
+			b := data[iNdEx]
+			iNdEx++
+			wire |= (uint64(b) & 0x7F) << shift
+			if b < 0x80 {
+				break
+			}
+		}
+		wireType := int(wire & 0x7)
+		switch wireType {
+		case 0:
+			// Varint: consume bytes until the continuation bit clears.
+			for shift := uint(0); ; shift += 7 {
+				if shift >= 64 {
+					return 0, ErrIntOverflowLog
+				}
+				if iNdEx >= l {
+					return 0, io.ErrUnexpectedEOF
+				}
+				iNdEx++
+				if data[iNdEx-1] < 0x80 {
+					break
+				}
+			}
+			return iNdEx, nil
+		case 1:
+			// Fixed 64-bit value.
+			iNdEx += 8
+			return iNdEx, nil
+		case 2:
+			// Length-delimited: varint byte count, then that many bytes.
+			var length int
+			for shift := uint(0); ; shift += 7 {
+				if shift >= 64 {
+					return 0, ErrIntOverflowLog
+				}
+				if iNdEx >= l {
+					return 0, io.ErrUnexpectedEOF
+				}
+				b := data[iNdEx]
+				iNdEx++
+				length |= (int(b) & 0x7F) << shift
+				if b < 0x80 {
+					break
+				}
+			}
+			// NOTE(review): iNdEx is advanced before the negative-length
+			// check, and the returned index is not bounds-checked against
+			// l here — callers guard with their own (iNdEx + skippy) > l
+			// test. Confirm this matches the generated-code contract.
+			iNdEx += length
+			if length < 0 {
+				return 0, ErrInvalidLengthLog
+			}
+			return iNdEx, nil
+		case 3:
+			// Start-group: skip nested fields until the matching
+			// end-group (wire type 4) tag is seen.
+			for {
+				var innerWire uint64
+				var start = iNdEx
+				for shift := uint(0); ; shift += 7 {
+					if shift >= 64 {
+						return 0, ErrIntOverflowLog
+					}
+					if iNdEx >= l {
+						return 0, io.ErrUnexpectedEOF
+					}
+					b := data[iNdEx]
+					iNdEx++
+					innerWire |= (uint64(b) & 0x7F) << shift
+					if b < 0x80 {
+						break
+					}
+				}
+				innerWireType := int(innerWire & 0x7)
+				if innerWireType == 4 {
+					break
+				}
+				// Recurse from the inner tag to skip the nested field.
+				next, err := skipLog(data[start:])
+				if err != nil {
+					return 0, err
+				}
+				iNdEx = start + next
+			}
+			return iNdEx, nil
+		case 4:
+			// Bare end-group tag: nothing further to consume.
+			return iNdEx, nil
+		case 5:
+			// Fixed 32-bit value.
+			iNdEx += 4
+			return iNdEx, nil
+		default:
+			return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
+		}
+	}
+	panic("unreachable")
+}
diff --git a/logs/alils/log_config.go b/logs/alils/log_config.go
new file mode 100755
index 00000000..e8564efb
--- /dev/null
+++ b/logs/alils/log_config.go
@@ -0,0 +1,42 @@
+package alils
+
+// InputDetail define log detail
+// It describes where and how log lines are collected (file paths,
+// parsing regexes, filters). Field names mirror the SLS JSON schema.
+type InputDetail struct {
+	LogType       string   `json:"logType"`
+	LogPath       string   `json:"logPath"`
+	FilePattern   string   `json:"filePattern"`
+	LocalStorage  bool     `json:"localStorage"`
+	TimeFormat    string   `json:"timeFormat"`
+	LogBeginRegex string   `json:"logBeginRegex"`
+	Regex         string   `json:"regex"`
+	Keys          []string `json:"key"`
+	FilterKeys    []string `json:"filterKey"`
+	FilterRegex   []string `json:"filterRegex"`
+	TopicFormat   string   `json:"topicFormat"`
+}
+
+// OutputDetail define the output detail
+// It names the SLS endpoint and logstore that collected logs are sent to.
+type OutputDetail struct {
+	Endpoint     string `json:"endpoint"`
+	LogStoreName string `json:"logstoreName"`
+}
+
+// LogConfig define Log Config
+// A LogConfig pairs an input specification with an output target and is
+// bound to the LogProject it was fetched from (see LogProject.GetConfig).
+type LogConfig struct {
+	Name         string       `json:"configName"`
+	InputType    string       `json:"inputType"`
+	InputDetail  InputDetail  `json:"inputDetail"`
+	OutputType   string       `json:"outputType"`
+	OutputDetail OutputDetail `json:"outputDetail"`
+
+	// CreateTime and LastModifyTime are server-side timestamps; they are
+	// not serialized to JSON (no tags).
+	CreateTime     uint32
+	LastModifyTime uint32
+
+	// project is the owning LogProject, set when the config is fetched.
+	project *LogProject
+}
+
+// GetAppliedMachineGroup returns applied machine group of this config.
+//
+// NOTE(review): the confName parameter is ignored — the lookup always
+// uses c.Name. Confirm whether the parameter should be used or removed.
+func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) {
+	groupNames, err = c.project.GetAppliedMachineGroups(c.Name)
+	return
+}
diff --git a/logs/alils/log_project.go b/logs/alils/log_project.go
new file mode 100755
index 00000000..59db8cbf
--- /dev/null
+++ b/logs/alils/log_project.go
@@ -0,0 +1,819 @@
+/*
+Package alils implements the SDK(v0.5.0) of Simple Log Service(abbr. SLS).
+
+For more description about SLS, please read this article:
+http://gitlab.alibaba-inc.com/sls/doc.
+*/
+package alils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httputil"
+)
+
+// Error message in SLS HTTP response.
+type errorMessage struct {
+ Code string `json:"errorCode"`
+ Message string `json:"errorMessage"`
+}
+
+// LogProject Define the Ali Project detail
+type LogProject struct {
+ Name string // Project name
+ Endpoint string // IP or hostname of SLS endpoint
+ AccessKeyID string
+ AccessKeySecret string
+}
+
+// NewLogProject creates a new SLS project.
+func NewLogProject(name, endpoint, AccessKeyID, accessKeySecret string) (p *LogProject, err error) {
+ p = &LogProject{
+ Name: name,
+ Endpoint: endpoint,
+ AccessKeyID: AccessKeyID,
+ AccessKeySecret: accessKeySecret,
+ }
+ return p, nil
+}
+
+// ListLogStore returns all logstore names of project p.
+// It issues GET /logstores and decodes the JSON logstore-name list.
+// On a non-200 response it returns the SLS error code and message.
+func (p *LogProject) ListLogStore() (storeNames []string, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	// The path is constant; fmt.Sprintf with no arguments was redundant
+	// (staticcheck S1039).
+	uri := "/logstores"
+	r, err := request(p, "GET", uri, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			err = fmt.Errorf("failed to list logstore")
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	type Body struct {
+		Count     int
+		LogStores []string
+	}
+	body := &Body{}
+
+	err = json.Unmarshal(buf, body)
+	if err != nil {
+		return
+	}
+
+	storeNames = body.LogStores
+
+	return
+}
+
+// GetLogStore returns logstore according by logstore name.
+func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ r, err := request(p, "GET", "/logstores/"+name, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to get logstore")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ s = &LogStore{}
+ err = json.Unmarshal(buf, s)
+ if err != nil {
+ return
+ }
+ s.project = p
+ return
+}
+
+// CreateLogStore creates a new logstore in SLS,
+// where name is logstore name,
+// and ttl is time-to-live(in day) of logs,
+// and shardCnt is the number of shards.
+func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) {
+
+ type Body struct {
+ Name string `json:"logstoreName"`
+ TTL int `json:"ttl"`
+ ShardCount int `json:"shardCount"`
+ }
+
+ store := &Body{
+ Name: name,
+ TTL: ttl,
+ ShardCount: shardCnt,
+ }
+
+ body, err := json.Marshal(store)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/json",
+ "Accept-Encoding": "deflate", // TODO: support lz4
+ }
+
+ r, err := request(p, "POST", "/logstores", h, body)
+ if err != nil {
+ return
+ }
+
+ body, err = ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to create logstore")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ return
+}
+
+// DeleteLogStore deletes a logstore according by logstore name.
+func (p *LogProject) DeleteLogStore(name string) (err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ r, err := request(p, "DELETE", "/logstores/"+name, h, nil)
+ if err != nil {
+ return
+ }
+
+ body, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to delete logstore")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
+
+// UpdateLogStore updates a logstore according by logstore name,
+// obviously we can't modify the logstore name itself.
+//
+// NOTE(review): the PUT is sent to "/logstores" without the logstore
+// name in the path, unlike GetLogStore/DeleteLogStore which use
+// "/logstores/"+name. Confirm against the SLS API whether the name
+// belongs in the URI here.
+func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) {
+
+	// Request body mirrors the SLS logstore JSON schema.
+	type Body struct {
+		Name       string `json:"logstoreName"`
+		TTL        int    `json:"ttl"`
+		ShardCount int    `json:"shardCount"`
+	}
+
+	store := &Body{
+		Name:       name,
+		TTL:        ttl,
+		ShardCount: shardCnt,
+	}
+
+	body, err := json.Marshal(store)
+	if err != nil {
+		return
+	}
+
+	h := map[string]string{
+		"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+		"Content-Type":      "application/json",
+		"Accept-Encoding":   "deflate", // TODO: support lz4
+	}
+
+	r, err := request(p, "PUT", "/logstores", h, body)
+	if err != nil {
+		return
+	}
+
+	body, err = ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		// Non-200: try to decode the structured SLS error; fall back to
+		// dumping the raw response when the body is not the error JSON.
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(body, errMsg)
+		if err != nil {
+			err = fmt.Errorf("failed to update logstore")
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	return
+}
+
+// ListMachineGroup returns machine group name list and the total number of machine groups.
+// The offset starts from 0 and the size is the max number of machine groups could be returned.
+func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ if size <= 0 {
+ size = 500
+ }
+
+ uri := fmt.Sprintf("/machinegroups?offset=%v&size=%v", offset, size)
+ r, err := request(p, "GET", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to list machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ type Body struct {
+ MachineGroups []string
+ Count int
+ Total int
+ }
+ body := &Body{}
+
+ err = json.Unmarshal(buf, body)
+ if err != nil {
+ return
+ }
+
+ m = body.MachineGroups
+ total = body.Total
+
+ return
+}
+
+// GetMachineGroup returns the machine group identified by name.
+// It issues GET /machinegroups/{name} and decodes the JSON body into a
+// MachineGroup bound to this project.
+func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	r, err := request(p, "GET", "/machinegroups/"+name, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		// Non-200: try to decode the structured SLS error; fall back to
+		// dumping the raw response when the body is not the error JSON.
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			err = fmt.Errorf("failed to get machine group:%v", name)
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	m = &MachineGroup{}
+	err = json.Unmarshal(buf, m)
+	if err != nil {
+		return
+	}
+	// Bind to the owning project so group methods can issue requests.
+	m.project = p
+	return
+}
+
+// CreateMachineGroup creates a new machine group in SLS.
+func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) {
+
+ body, err := json.Marshal(m)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/json",
+ "Accept-Encoding": "deflate", // TODO: support lz4
+ }
+
+ r, err := request(p, "POST", "/machinegroups", h, body)
+ if err != nil {
+ return
+ }
+
+ body, err = ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to create machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ return
+}
+
+// UpdateMachineGroup updates a machine group.
+func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) {
+
+ body, err := json.Marshal(m)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/json",
+ "Accept-Encoding": "deflate", // TODO: support lz4
+ }
+
+ r, err := request(p, "PUT", "/machinegroups/"+m.Name, h, body)
+ if err != nil {
+ return
+ }
+
+ body, err = ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to update machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ return
+}
+
+// DeleteMachineGroup deletes machine group according machine group name.
+func (p *LogProject) DeleteMachineGroup(name string) (err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ r, err := request(p, "DELETE", "/machinegroups/"+name, h, nil)
+ if err != nil {
+ return
+ }
+
+ body, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to delete machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
+
+// ListConfig returns config names list and the total number of configs.
+// The offset starts from 0 and the size is the max number of configs
+// could be returned; a non-positive size defaults to 100.
+func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	if size <= 0 {
+		size = 100
+	}
+
+	uri := fmt.Sprintf("/configs?offset=%v&size=%v", offset, size)
+	r, err := request(p, "GET", uri, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			// Fixed copy-paste: this call lists configs, it does not
+			// delete a machine group.
+			err = fmt.Errorf("failed to list configs")
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	type Body struct {
+		Total   int
+		Configs []string
+	}
+	body := &Body{}
+
+	err = json.Unmarshal(buf, body)
+	if err != nil {
+		return
+	}
+
+	cfgNames = body.Configs
+	total = body.Total
+	return
+}
+
+// GetConfig returns the config identified by name.
+// It issues GET /configs/{name} and decodes the JSON body into a
+// LogConfig bound to this project.
+func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	r, err := request(p, "GET", "/configs/"+name, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			// Fixed copy-paste: this call gets a config, it does not
+			// delete one.
+			err = fmt.Errorf("failed to get config")
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	c = &LogConfig{}
+	err = json.Unmarshal(buf, c)
+	if err != nil {
+		return
+	}
+	// Bind to the owning project so config methods can issue requests.
+	c.project = p
+	return
+}
+
+// UpdateConfig updates a config.
+func (p *LogProject) UpdateConfig(c *LogConfig) (err error) {
+
+ body, err := json.Marshal(c)
+ if err != nil {
+ return
+ }
+
+ h := map[string]string{
+ "x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+ "Content-Type": "application/json",
+ "Accept-Encoding": "deflate", // TODO: support lz4
+ }
+
+ r, err := request(p, "PUT", "/configs/"+c.Name, h, body)
+ if err != nil {
+ return
+ }
+
+ body, err = ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to update config")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ return
+}
+
+// CreateConfig creates a new config in SLS.
+// The config c is JSON-encoded and POSTed to /configs.
+func (p *LogProject) CreateConfig(c *LogConfig) (err error) {
+
+	body, err := json.Marshal(c)
+	if err != nil {
+		return
+	}
+
+	h := map[string]string{
+		"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
+		"Content-Type":      "application/json",
+		"Accept-Encoding":   "deflate", // TODO: support lz4
+	}
+
+	r, err := request(p, "POST", "/configs", h, body)
+	if err != nil {
+		return
+	}
+
+	body, err = ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(body, errMsg)
+		if err != nil {
+			// Fixed copy-paste: this call creates a config, it does not
+			// update one.
+			err = fmt.Errorf("failed to create config")
+			dump, _ := httputil.DumpResponse(r, true)
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	return
+}
+
+// DeleteConfig deletes a config according by config name.
+func (p *LogProject) DeleteConfig(name string) (err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ r, err := request(p, "DELETE", "/configs/"+name, h, nil)
+ if err != nil {
+ return
+ }
+
+ body, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(body, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to delete config")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
+
+// GetAppliedMachineGroups returns applied machine group names list according config name.
+func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []string, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ uri := fmt.Sprintf("/configs/%v/machinegroups", confName)
+ r, err := request(p, "GET", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to get applied machine groups")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ type Body struct {
+ Count int
+ Machinegroups []string
+ }
+
+ body := &Body{}
+ err = json.Unmarshal(buf, body)
+ if err != nil {
+ return
+ }
+
+ groupNames = body.Machinegroups
+ return
+}
+
+// GetAppliedConfigs returns applied config names list according machine group name groupName.
+func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ uri := fmt.Sprintf("/machinegroups/%v/configs", groupName)
+ r, err := request(p, "GET", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to applied configs")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ type Cfg struct {
+ Count int `json:"count"`
+ Configs []string `json:"configs"`
+ }
+
+ body := &Cfg{}
+ err = json.Unmarshal(buf, body)
+ if err != nil {
+ return
+ }
+
+ confNames = body.Configs
+ return
+}
+
+// ApplyConfigToMachineGroup applies config to machine group.
+func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName)
+ r, err := request(p, "PUT", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to apply config to machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
+
+// RemoveConfigFromMachineGroup removes config from machine group.
+func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ uri := fmt.Sprintf("/machinegroups/%v/configs/%v", groupName, confName)
+ r, err := request(p, "DELETE", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to remove config from machine group")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Printf("%s\n", dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+ return
+}
diff --git a/logs/alils/log_store.go b/logs/alils/log_store.go
new file mode 100755
index 00000000..fa502736
--- /dev/null
+++ b/logs/alils/log_store.go
@@ -0,0 +1,271 @@
+package alils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httputil"
+ "strconv"
+
+ lz4 "github.com/cloudflare/golz4"
+ "github.com/gogo/protobuf/proto"
+)
+
+// LogStore Store the logs
+type LogStore struct {
+ Name string `json:"logstoreName"`
+ TTL int
+ ShardCount int
+
+ CreateTime uint32
+ LastModifyTime uint32
+
+ project *LogProject
+}
+
+// Shard define the Log Shard
+type Shard struct {
+ ShardID int `json:"shardID"`
+}
+
+// ListShards returns shard id list of this logstore.
+// It issues GET /logstores/{name}/shards and collects the shard IDs
+// from the JSON shard list.
+func (s *LogStore) ListShards() (shardIDs []int, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	uri := fmt.Sprintf("/logstores/%v/shards", s.Name)
+	r, err := request(s.project, "GET", uri, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			// Fixed copy-paste: this call lists shards, not logstores.
+			err = fmt.Errorf("failed to list shards")
+			dump, _ := httputil.DumpResponse(r, true)
+			// %s prints the dump as text; Println would render the byte
+			// slice as a list of numbers (matches log_project.go style).
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	var shards []*Shard
+	err = json.Unmarshal(buf, &shards)
+	if err != nil {
+		return
+	}
+
+	for _, v := range shards {
+		shardIDs = append(shardIDs, v.ShardID)
+	}
+	return
+}
+
+// PutLogs put logs into logstore.
+// The callers should transform user logs into LogGroup.
+// The group is protobuf-marshaled, LZ4-compressed, and POSTed to
+// /logstores/{name}; the x-sls-bodyrawsize header carries the
+// uncompressed length so the server can decompress.
+func (s *LogStore) PutLogs(lg *LogGroup) (err error) {
+	body, err := proto.Marshal(lg)
+	if err != nil {
+		return
+	}
+
+	// Compress the protobuf body with LZ4; out[:n] is the compressed payload.
+	out := make([]byte, lz4.CompressBound(body))
+	n, err := lz4.Compress(body, out)
+	if err != nil {
+		return
+	}
+
+	h := map[string]string{
+		"x-sls-compresstype": "lz4",
+		"x-sls-bodyrawsize":  fmt.Sprintf("%v", len(body)),
+		"Content-Type":       "application/x-protobuf",
+	}
+
+	uri := fmt.Sprintf("/logstores/%v", s.Name)
+	r, err := request(s.project, "POST", uri, h, out[:n])
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		// Non-200: try to decode the structured SLS error; fall back to
+		// dumping the raw response when the body is not the error JSON.
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			err = fmt.Errorf("failed to put logs")
+			dump, _ := httputil.DumpResponse(r, true)
+			// NOTE(review): Println renders the dump byte slice as
+			// numbers; siblings use fmt.Printf("%s\n", dump).
+			fmt.Println(dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+	return
+}
+
+// GetCursor gets log cursor of one shard specified by shardID.
+// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end".
+// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore
+func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error) {
+ h := map[string]string{
+ "x-sls-bodyrawsize": "0",
+ }
+
+ uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v",
+ s.Name, shardID, from)
+
+ r, err := request(s.project, "GET", uri, h, nil)
+ if err != nil {
+ return
+ }
+
+ buf, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return
+ }
+
+ if r.StatusCode != http.StatusOK {
+ errMsg := &errorMessage{}
+ err = json.Unmarshal(buf, errMsg)
+ if err != nil {
+ err = fmt.Errorf("failed to get cursor")
+ dump, _ := httputil.DumpResponse(r, true)
+ fmt.Println(dump)
+ return
+ }
+ err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+ return
+ }
+
+ type Body struct {
+ Cursor string
+ }
+ body := &Body{}
+
+ err = json.Unmarshal(buf, body)
+ if err != nil {
+ return
+ }
+ cursor = body.Cursor
+ return
+}
+
+// GetLogsBytes gets logs binary data from shard specified by shardID according cursor.
+// The logGroupMaxCount is the max number of logGroup could be returned.
+// The nextCursor is the next cursor can be used to read logs at next time.
+// The returned bytes are the LZ4-decompressed protobuf payload; decode
+// them with LogsBytesDecode.
+func (s *LogStore) GetLogsBytes(shardID int, cursor string,
+	logGroupMaxCount int) (out []byte, nextCursor string, err error) {
+
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+		"Accept":            "application/x-protobuf",
+		"Accept-Encoding":   "lz4",
+	}
+
+	uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v",
+		s.Name, shardID, cursor, logGroupMaxCount)
+
+	r, err := request(s.project, "GET", uri, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			// Fixed copy-paste: this call fetches logs, not a cursor.
+			err = fmt.Errorf("failed to get logs")
+			dump, _ := httputil.DumpResponse(r, true)
+			// %s prints the dump as text; Println would render the byte
+			// slice as a list of numbers.
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	// The server must confirm LZ4 compression before we decompress.
+	v, ok := r.Header["X-Sls-Compresstype"]
+	if !ok || len(v) == 0 {
+		err = fmt.Errorf("can't find 'x-sls-compresstype' header")
+		return
+	}
+	if v[0] != "lz4" {
+		err = fmt.Errorf("unexpected compress type:%v", v[0])
+		return
+	}
+
+	v, ok = r.Header["X-Sls-Cursor"]
+	if !ok || len(v) == 0 {
+		err = fmt.Errorf("can't find 'x-sls-cursor' header")
+		return
+	}
+	nextCursor = v[0]
+
+	// The raw (uncompressed) size tells us how large the output buffer
+	// must be for lz4.Uncompress.
+	v, ok = r.Header["X-Sls-Bodyrawsize"]
+	if !ok || len(v) == 0 {
+		err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header")
+		return
+	}
+	bodyRawSize, err := strconv.Atoi(v[0])
+	if err != nil {
+		return
+	}
+
+	out = make([]byte, bodyRawSize)
+	err = lz4.Uncompress(buf, out)
+	if err != nil {
+		return
+	}
+
+	return
+}
+
+// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API
+func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) {
+
+ gl = &LogGroupList{}
+ err = proto.Unmarshal(data, gl)
+ if err != nil {
+ return
+ }
+
+ return
+}
+
+// GetLogs gets logs from shard specified by shardID according cursor.
+// The logGroupMaxCount is the max number of logGroup could be returned.
+// The nextCursor is the next curosr can be used to read logs at next time.
+func (s *LogStore) GetLogs(shardID int, cursor string,
+ logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) {
+
+ out, nextCursor, err := s.GetLogsBytes(shardID, cursor, logGroupMaxCount)
+ if err != nil {
+ return
+ }
+
+ gl, err = LogsBytesDecode(out)
+ if err != nil {
+ return
+ }
+
+ return
+}
diff --git a/logs/alils/machine_group.go b/logs/alils/machine_group.go
new file mode 100755
index 00000000..b6c69a14
--- /dev/null
+++ b/logs/alils/machine_group.go
@@ -0,0 +1,91 @@
+package alils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httputil"
+)
+
+// MachineGroupAttribute define the Attribute
+type MachineGroupAttribute struct {
+ ExternalName string `json:"externalName"`
+ TopicName string `json:"groupTopic"`
+}
+
+// MachineGroup define the machine Group
+type MachineGroup struct {
+ Name string `json:"groupName"`
+ Type string `json:"groupType"`
+ MachineIDType string `json:"machineIdentifyType"`
+ MachineIDList []string `json:"machineList"`
+
+ Attribute MachineGroupAttribute `json:"groupAttribute"`
+
+ CreateTime uint32
+ LastModifyTime uint32
+
+ project *LogProject
+}
+
+// Machine define the Machine
+type Machine struct {
+ IP string
+ UniqueID string `json:"machine-uniqueid"`
+ UserdefinedID string `json:"userdefined-id"`
+}
+
+// MachineList define the Machine List
+type MachineList struct {
+ Total int
+ Machines []*Machine
+}
+
+// ListMachines returns machine list of this machine group.
+// It issues GET /machinegroups/{name}/machines and returns the machines
+// together with the server-reported total.
+func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) {
+	h := map[string]string{
+		"x-sls-bodyrawsize": "0",
+	}
+
+	uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name)
+	r, err := request(m.project, "GET", uri, h, nil)
+	if err != nil {
+		return
+	}
+
+	buf, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return
+	}
+
+	if r.StatusCode != http.StatusOK {
+		errMsg := &errorMessage{}
+		err = json.Unmarshal(buf, errMsg)
+		if err != nil {
+			// Fixed copy-paste: this call lists machines, it does not
+			// remove a config from the group.
+			err = fmt.Errorf("failed to list machines")
+			dump, _ := httputil.DumpResponse(r, true)
+			// %s prints the dump as text; Println would render the byte
+			// slice as a list of numbers.
+			fmt.Printf("%s\n", dump)
+			return
+		}
+		err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
+		return
+	}
+
+	body := &MachineList{}
+	err = json.Unmarshal(buf, body)
+	if err != nil {
+		return
+	}
+
+	ms = body.Machines
+	total = body.Total
+
+	return
+}
+
+// GetAppliedConfigs returns applied configs of this machine group.
+func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) {
+ confNames, err = m.project.GetAppliedConfigs(m.Name)
+ return
+}
diff --git a/logs/alils/request.go b/logs/alils/request.go
new file mode 100755
index 00000000..50d9c43c
--- /dev/null
+++ b/logs/alils/request.go
@@ -0,0 +1,62 @@
+package alils
+
+import (
+ "bytes"
+ "crypto/md5"
+ "fmt"
+ "net/http"
+)
+
+// request sends a request to SLS.
+func request(project *LogProject, method, uri string, headers map[string]string,
+ body []byte) (resp *http.Response, err error) {
+
+ // The caller should provide 'x-sls-bodyrawsize' header
+ if _, ok := headers["x-sls-bodyrawsize"]; !ok {
+ err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header")
+ return
+ }
+
+ // SLS public request headers
+ headers["Host"] = project.Name + "." + project.Endpoint
+ headers["Date"] = nowRFC1123()
+ headers["x-sls-apiversion"] = version
+ headers["x-sls-signaturemethod"] = signatureMethod
+ if body != nil {
+ bodyMD5 := fmt.Sprintf("%X", md5.Sum(body))
+ headers["Content-MD5"] = bodyMD5
+
+ if _, ok := headers["Content-Type"]; !ok {
+ err = fmt.Errorf("Can't find 'Content-Type' header")
+ return
+ }
+ }
+
+ // Calc Authorization
+ // Authorization = "SLS :"
+ digest, err := signature(project, method, uri, headers)
+ if err != nil {
+ return
+ }
+ auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyID, digest)
+ headers["Authorization"] = auth
+
+ // Initialize http request
+ reader := bytes.NewReader(body)
+ urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri)
+ req, err := http.NewRequest(method, urlStr, reader)
+ if err != nil {
+ return
+ }
+ for k, v := range headers {
+ req.Header.Add(k, v)
+ }
+
+ // Get ready to do request
+ resp, err = http.DefaultClient.Do(req)
+ if err != nil {
+ return
+ }
+
+ return
+}
diff --git a/logs/alils/signature.go b/logs/alils/signature.go
new file mode 100755
index 00000000..2d611307
--- /dev/null
+++ b/logs/alils/signature.go
@@ -0,0 +1,111 @@
+package alils
+
+import (
+ "crypto/hmac"
+ "crypto/sha1"
+ "encoding/base64"
+ "fmt"
+ "net/url"
+ "sort"
+ "strings"
+ "time"
+)
+
+// GMT location
+var gmtLoc = time.FixedZone("GMT", 0)
+
+// NowRFC1123 returns now time in RFC1123 format with GMT timezone,
+// eg. "Mon, 02 Jan 2006 15:04:05 GMT".
+func nowRFC1123() string {
+ return time.Now().In(gmtLoc).Format(time.RFC1123)
+}
+
// signature calculates a request's signature digest: base64(HMAC-SHA1 of
// the canonical sign string, keyed by the project's AccessKeySecret).
func signature(project *LogProject, method, uri string,
	headers map[string]string) (digest string, err error) {
	var contentMD5, contentType, date, canoHeaders, canoResource string
	var slsHeaderKeys sort.StringSlice

	// SignString = VERB + "\n"
	//             + CONTENT-MD5 + "\n"
	//             + CONTENT-TYPE + "\n"
	//             + DATE + "\n"
	//             + CanonicalizedSLSHeaders + "\n"
	//             + CanonicalizedResource

	// Content-MD5 / Content-Type are optional (empty when absent).
	if val, ok := headers["Content-MD5"]; ok {
		contentMD5 = val
	}

	if val, ok := headers["Content-Type"]; ok {
		contentType = val
	}

	// Date is mandatory: without it the signature cannot be verified.
	date, ok := headers["Date"]
	if !ok {
		err = fmt.Errorf("Can't find 'Date' header")
		return
	}

	// Calc CanonicalizedSLSHeaders: every "x-sls-*" header, lowercased,
	// trimmed, and sorted by key, joined with '\n' (no trailing newline).
	slsHeaders := make(map[string]string, len(headers))
	for k, v := range headers {
		l := strings.TrimSpace(strings.ToLower(k))
		if strings.HasPrefix(l, "x-sls-") {
			slsHeaders[l] = strings.TrimSpace(v)
			slsHeaderKeys = append(slsHeaderKeys, l)
		}
	}

	sort.Sort(slsHeaderKeys)
	for i, k := range slsHeaderKeys {
		canoHeaders += k + ":" + slsHeaders[k]
		if i+1 < len(slsHeaderKeys) {
			canoHeaders += "\n"
		}
	}

	// Calc CanonicalizedResource: escaped path plus sorted query string.
	u, err := url.Parse(uri)
	if err != nil {
		return
	}

	canoResource += url.QueryEscape(u.Path)
	if u.RawQuery != "" {
		var keys sort.StringSlice

		vals := u.Query()
		for k := range vals {
			keys = append(keys, k)
		}

		sort.Sort(keys)
		canoResource += "?"
		for i, k := range keys {
			if i > 0 {
				canoResource += "&"
			}

			// NOTE(review): repeated values for the same key are
			// concatenated without any separator here — confirm this
			// matches the SLS canonicalization spec.
			for _, v := range vals[k] {
				canoResource += k + "=" + v
			}
		}
	}

	signStr := method + "\n" +
		contentMD5 + "\n" +
		contentType + "\n" +
		date + "\n" +
		canoHeaders + "\n" +
		canoResource

	// Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret))
	mac := hmac.New(sha1.New, []byte(project.AccessKeySecret))
	_, err = mac.Write([]byte(signStr))
	if err != nil {
		return
	}
	digest = base64.StdEncoding.EncodeToString(mac.Sum(nil))
	return
}
diff --git a/logs/color_windows.go b/logs/color_windows.go
index deee4c87..4e28f188 100644
--- a/logs/color_windows.go
+++ b/logs/color_windows.go
@@ -361,7 +361,7 @@ func isParameterChar(b byte) bool {
}
func (cw *ansiColorWriter) Write(p []byte) (int, error) {
- r, nw, first, last := 0, 0, 0, 0
+ var r, nw, first, last int
if cw.mode != DiscardNonColorEscSeq {
cw.state = outsideCsiCode
cw.resetBuffer()
diff --git a/logs/console.go b/logs/console.go
index e6bf6c29..e75f2a1b 100644
--- a/logs/console.go
+++ b/logs/console.go
@@ -41,7 +41,7 @@ var colors = []brush{
newBrush("1;33"), // Warning yellow
newBrush("1;32"), // Notice green
newBrush("1;34"), // Informational blue
- newBrush("1;34"), // Debug blue
+ newBrush("1;44"), // Debug Background blue
}
// consoleWriter implements LoggerInterface and writes messages to terminal.
diff --git a/logs/file.go b/logs/file.go
index 42146dae..e8c1f37e 100644
--- a/logs/file.go
+++ b/logs/file.go
@@ -56,17 +56,20 @@ type fileLogWriter struct {
Perm string `json:"perm"`
+ RotatePerm string `json:"rotateperm"`
+
fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix
}
// newFileWriter create a FileLogWriter returning as LoggerInterface.
func newFileWriter() Logger {
w := &fileLogWriter{
- Daily: true,
- MaxDays: 7,
- Rotate: true,
- Level: LevelTrace,
- Perm: "0660",
+ Daily: true,
+ MaxDays: 7,
+ Rotate: true,
+ RotatePerm: "0440",
+ Level: LevelTrace,
+ Perm: "0660",
}
return w
}
@@ -170,7 +173,7 @@ func (w *fileLogWriter) initFd() error {
fd := w.fileWriter
fInfo, err := fd.Stat()
if err != nil {
- return fmt.Errorf("get stat err: %s\n", err)
+ return fmt.Errorf("get stat err: %s", err)
}
w.maxSizeCurSize = int(fInfo.Size())
w.dailyOpenTime = time.Now()
@@ -193,16 +196,14 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) {
y, m, d := openTime.Add(24 * time.Hour).Date()
nextDay := time.Date(y, m, d, 0, 0, 0, 0, openTime.Location())
tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100))
- select {
- case <-tm.C:
- w.Lock()
- if w.needRotate(0, time.Now().Day()) {
- if err := w.doRotate(time.Now()); err != nil {
- fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
- }
+ <-tm.C
+ w.Lock()
+ if w.needRotate(0, time.Now().Day()) {
+ if err := w.doRotate(time.Now()); err != nil {
+ fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err)
}
- w.Unlock()
}
+ w.Unlock()
}
func (w *fileLogWriter) lines() (int, error) {
@@ -239,8 +240,12 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
// Find the next available number
num := 1
fName := ""
+ rotatePerm, err := strconv.ParseInt(w.RotatePerm, 8, 64)
+ if err != nil {
+ return err
+ }
- _, err := os.Lstat(w.Filename)
+ _, err = os.Lstat(w.Filename)
if err != nil {
//even if the file is not exist or other ,we should RESTART the logger
goto RESTART_LOGGER
@@ -261,7 +266,7 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
}
// return error if the last file checked still existed
if err == nil {
- return fmt.Errorf("Rotate: Cannot find free log number to rename %s\n", w.Filename)
+ return fmt.Errorf("Rotate: Cannot find free log number to rename %s", w.Filename)
}
// close fileWriter before rename
@@ -270,20 +275,24 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
// Rename the file to its new found name
// even if occurs error,we MUST guarantee to restart new logger
err = os.Rename(w.Filename, fName)
- // re-start logger
+ if err != nil {
+ goto RESTART_LOGGER
+ }
+
+ err = os.Chmod(fName, os.FileMode(rotatePerm))
+
RESTART_LOGGER:
startLoggerErr := w.startLogger()
go w.deleteOldLog()
if startLoggerErr != nil {
- return fmt.Errorf("Rotate StartLogger: %s\n", startLoggerErr)
+ return fmt.Errorf("Rotate StartLogger: %s", startLoggerErr)
}
if err != nil {
- return fmt.Errorf("Rotate: %s\n", err)
+ return fmt.Errorf("Rotate: %s", err)
}
return nil
-
}
func (w *fileLogWriter) deleteOldLog() {
diff --git a/logs/file_test.go b/logs/file_test.go
index 69a66d84..626521b9 100644
--- a/logs/file_test.go
+++ b/logs/file_test.go
@@ -162,14 +162,35 @@ func TestFileRotate_05(t *testing.T) {
testFileDailyRotate(t, fn1, fn2)
os.Remove(fn)
}
-
+func TestFileRotate_06(t *testing.T) { //test file mode
+ log := NewLogger(10000)
+ log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`)
+ log.Debug("debug")
+ log.Info("info")
+ log.Notice("notice")
+ log.Warning("warning")
+ log.Error("error")
+ log.Alert("alert")
+ log.Critical("critical")
+ log.Emergency("emergency")
+ rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log"
+ s, _ := os.Lstat(rotateName)
+ if s.Mode() != 0440 {
+ os.Remove(rotateName)
+ os.Remove("test3.log")
+ t.Fatal("rotate file mode error")
+ }
+ os.Remove(rotateName)
+ os.Remove("test3.log")
+}
func testFileRotate(t *testing.T, fn1, fn2 string) {
fw := &fileLogWriter{
- Daily: true,
- MaxDays: 7,
- Rotate: true,
- Level: LevelTrace,
- Perm: "0660",
+ Daily: true,
+ MaxDays: 7,
+ Rotate: true,
+ Level: LevelTrace,
+ Perm: "0660",
+ RotatePerm: "0440",
}
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
@@ -188,11 +209,12 @@ func testFileRotate(t *testing.T, fn1, fn2 string) {
func testFileDailyRotate(t *testing.T, fn1, fn2 string) {
fw := &fileLogWriter{
- Daily: true,
- MaxDays: 7,
- Rotate: true,
- Level: LevelTrace,
- Perm: "0660",
+ Daily: true,
+ MaxDays: 7,
+ Rotate: true,
+ Level: LevelTrace,
+ Perm: "0660",
+ RotatePerm: "0440",
}
fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1))
fw.dailyOpenTime = time.Now().Add(-24 * time.Hour)
diff --git a/logs/jianliao.go b/logs/jianliao.go
index 16773c93..88ba0f9a 100644
--- a/logs/jianliao.go
+++ b/logs/jianliao.go
@@ -25,11 +25,7 @@ func newJLWriter() Logger {
// Init JLWriter with json config string
func (s *JLWriter) Init(jsonconfig string) error {
- err := json.Unmarshal([]byte(jsonconfig), s)
- if err != nil {
- return err
- }
- return nil
+ return json.Unmarshal([]byte(jsonconfig), s)
}
// WriteMsg write message in smtp writer.
@@ -65,12 +61,10 @@ func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error {
// Flush implementing method. empty.
func (s *JLWriter) Flush() {
- return
}
// Destroy implementing method. empty.
func (s *JLWriter) Destroy() {
- return
}
func init() {
diff --git a/logs/log.go b/logs/log.go
index 806ebaa0..0e97a70e 100644
--- a/logs/log.go
+++ b/logs/log.go
@@ -71,6 +71,7 @@ const (
AdapterEs = "es"
AdapterJianLiao = "jianliao"
AdapterSlack = "slack"
+ AdapterAliLS = "alils"
)
// Legacy log level constants to ensure backwards compatibility.
@@ -274,7 +275,7 @@ func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error
line = 0
}
_, filename := path.Split(file)
- msg = "[" + filename + ":" + strconv.FormatInt(int64(line), 10) + "] " + msg
+ msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg
}
//set level info in front of filename info
@@ -491,9 +492,9 @@ func (bl *BeeLogger) flush() {
}
// beeLogger references the used application logger.
-var beeLogger *BeeLogger = NewLogger()
+var beeLogger = NewLogger()
-// GetLogger returns the default BeeLogger
+// GetBeeLogger returns the default BeeLogger
func GetBeeLogger() *BeeLogger {
return beeLogger
}
@@ -533,6 +534,7 @@ func Reset() {
beeLogger.Reset()
}
+// Async set the beelogger with Async mode and hold msglen messages
func Async(msgLen ...int64) *BeeLogger {
return beeLogger.Async(msgLen...)
}
@@ -560,11 +562,7 @@ func SetLogFuncCallDepth(d int) {
// SetLogger sets a new logger.
func SetLogger(adapter string, config ...string) error {
- err := beeLogger.SetLogger(adapter, config...)
- if err != nil {
- return err
- }
- return nil
+ return beeLogger.SetLogger(adapter, config...)
}
// Emergency logs a message at emergency level.
diff --git a/logs/logger.go b/logs/logger.go
index e0abfdc4..b5d7255f 100644
--- a/logs/logger.go
+++ b/logs/logger.go
@@ -139,6 +139,11 @@ var (
reset = string([]byte{27, 91, 48, 109})
)
+// ColorByStatus return color by http code
+// 2xx return Green
+// 3xx return White
+// 4xx return Yellow
+// 5xx return Red
func ColorByStatus(cond bool, code int) string {
switch {
case code >= 200 && code < 300:
@@ -152,6 +157,14 @@ func ColorByStatus(cond bool, code int) string {
}
}
+// ColorByMethod return color by http code
+// GET return Blue
+// POST return Cyan
+// PUT return Yellow
+// DELETE return Red
+// PATCH return Green
+// HEAD return Magenta
+// OPTIONS return WHITE
func ColorByMethod(cond bool, method string) string {
switch method {
case "GET":
@@ -173,10 +186,10 @@ func ColorByMethod(cond bool, method string) string {
}
}
-// Guard Mutex to guarantee atomicity of W32Debug(string) function
+// Guard Mutex to guarantee atomicity of the W32Debug(string) function
var mu sync.Mutex
-// Helper method to output colored logs in Windows terminals
+// W32Debug Helper method to output colored logs in Windows terminals
func W32Debug(msg string) {
mu.Lock()
defer mu.Unlock()
diff --git a/logs/slack.go b/logs/slack.go
index 90f009cb..1cd2e5ae 100644
--- a/logs/slack.go
+++ b/logs/slack.go
@@ -21,11 +21,7 @@ func newSLACKWriter() Logger {
// Init SLACKWriter with json config string
func (s *SLACKWriter) Init(jsonconfig string) error {
- err := json.Unmarshal([]byte(jsonconfig), s)
- if err != nil {
- return err
- }
- return nil
+ return json.Unmarshal([]byte(jsonconfig), s)
}
// WriteMsg write message in smtp writer.
@@ -53,12 +49,10 @@ func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error {
// Flush implementing method. empty.
func (s *SLACKWriter) Flush() {
- return
}
// Destroy implementing method. empty.
func (s *SLACKWriter) Destroy() {
- return
}
func init() {
diff --git a/logs/smtp.go b/logs/smtp.go
index 834130ef..6208d7b8 100644
--- a/logs/smtp.go
+++ b/logs/smtp.go
@@ -52,11 +52,7 @@ func newSMTPWriter() Logger {
// "level":LevelError
// }
func (s *SMTPWriter) Init(jsonconfig string) error {
- err := json.Unmarshal([]byte(jsonconfig), s)
- if err != nil {
- return err
- }
- return nil
+ return json.Unmarshal([]byte(jsonconfig), s)
}
func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth {
@@ -106,7 +102,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd
if err != nil {
return err
}
- _, err = w.Write([]byte(msgContent))
+ _, err = w.Write(msgContent)
if err != nil {
return err
}
@@ -116,12 +112,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd
return err
}
- err = client.Quit()
- if err != nil {
- return err
- }
-
- return nil
+ return client.Quit()
}
// WriteMsg write message in smtp writer.
@@ -147,12 +138,10 @@ func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error {
// Flush implementing method. empty.
func (s *SMTPWriter) Flush() {
- return
}
// Destroy implementing method. empty.
func (s *SMTPWriter) Destroy() {
- return
}
func init() {
diff --git a/migration/ddl.go b/migration/ddl.go
index 51243337..cea10355 100644
--- a/migration/ddl.go
+++ b/migration/ddl.go
@@ -14,40 +14,382 @@
package migration
-// Table store the tablename and Column
-type Table struct {
- TableName string
- Columns []*Column
import (
	"fmt"
	"strings"

	"github.com/astaxie/beego"
)
+
// Index defines a single index on a table.
type Index struct {
	Name string
}

// Unique defines a single unique key combination; columns sharing the
// same Definition belong to the same key.
type Unique struct {
	Definition string
	Columns    []*Column
}

// Column defines a single column of a table. The string fields hold the
// SQL fragments emitted verbatim into the generated statement.
type Column struct {
	Name     string
	Inc      string // "auto_increment" when enabled via SetAuto
	Null     string // "" (nullable) or "NOT NULL"
	Default  string // "DEFAULT <value>" when set via SetDefault
	Unsign   string // "UNSIGNED" when enabled via SetUnsigned
	DataType string
	remove   bool // marked for DROP COLUMN on alter migrations
	Modify   bool
}

// Foreign defines a single foreign-key relationship; it embeds Column so
// the referencing column can be configured with the standard setters.
type Foreign struct {
	ForeignTable  string
	ForeignColumn string
	OnDelete      string
	OnUpdate      string
	Column
}

// RenameColumn allows renaming of columns; the Old* fields preserve the
// previous definition so the migration can be reversed.
type RenameColumn struct {
	OldName    string
	OldNull    string
	OldDefault string
	OldUnsign  string
	OldDataType string
	NewName    string
	Column
}
-// TableDDL is still in think
-func TableDDL(tbname string, columns ...Column) string {
- return ""
// CreateTable queues a CREATE TABLE migration for tablename with the given
// storage engine and character set.
// NOTE(review): the variadic p ...func() parameter is unused here —
// confirm whether it is reserved for future use before removing it.
func (m *Migration) CreateTable(tablename, engine, charset string, p ...func()) {
	m.TableName = tablename
	m.Engine = engine
	m.Charset = charset
	m.ModifyType = "create"
}

// AlterTable queues an ALTER TABLE migration for tablename.
func (m *Migration) AlterTable(tablename string) {
	m.TableName = tablename
	m.ModifyType = "alter"
}
+
// NewCol creates a new standard column named name and attaches it to the
// migration.
func (m *Migration) NewCol(name string) *Column {
	col := &Column{Name: name}
	m.AddColumns(col)
	return col
}

// PriCol creates a new column named name, attaches it to the migration and
// adds it to the primary key.
func (m *Migration) PriCol(name string) *Column {
	col := &Column{Name: name}
	m.AddColumns(col)
	m.AddPrimary(col)
	return col
}

// UniCol creates a column named name and appends it to the unique key
// identified by uni; the key is created on first use of that name.
func (m *Migration) UniCol(uni, name string) *Column {
	col := &Column{Name: name}
	m.AddColumns(col)

	uniqueOriginal := &Unique{}

	// Reuse an existing unique key with the same definition, if any.
	for _, unique := range m.Uniques {
		if unique.Definition == uni {
			unique.AddColumnsToUnique(col)
			uniqueOriginal = unique
		}
	}
	// No existing key matched: create a new one.
	if uniqueOriginal.Definition == "" {
		unique := &Unique{Definition: uni}
		unique.AddColumnsToUnique(col)
		m.AddUnique(unique)
	}

	return col
}
+
// ForeignCol creates a new foreign-key column colname referencing
// foreigntable(foreigncol), attaches it to the migration, and returns it
// so the ON DELETE / ON UPDATE actions can be chained.
func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) {

	foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable}
	foreign.Name = colname
	m.AddForeign(foreign)
	return foreign
}
+
+//SetOnDelete sets the on delete of foreign
+func (foreign *Foreign) SetOnDelete(del string) *Foreign {
+ foreign.OnDelete = "ON DELETE" + del
+ return foreign
+}
+
+//SetOnUpdate sets the on update of foreign
+func (foreign *Foreign) SetOnUpdate(update string) *Foreign {
+ foreign.OnUpdate = "ON UPDATE" + update
+ return foreign
+}
+
// Remove marks the column to be dropped on an alter migration; the reverse
// migration will re-create it.
func (c *Column) Remove() {
	c.remove = true
}

// SetAuto enables auto_increment of the column (use at most once per table).
func (c *Column) SetAuto(inc bool) *Column {
	if inc {
		c.Inc = "auto_increment"
	}
	return c
}

// SetNullable sets whether the column may be NULL; false emits "NOT NULL".
func (c *Column) SetNullable(null bool) *Column {
	if null {
		c.Null = ""

	} else {
		c.Null = "NOT NULL"
	}
	return c
}

// SetDefault sets the default value; "DEFAULT " is prepended automatically.
func (c *Column) SetDefault(def string) *Column {
	c.Default = "DEFAULT " + def
	return c
}

// SetUnsigned marks the column as UNSIGNED (integer types only).
func (c *Column) SetUnsigned(unsign bool) *Column {
	if unsign {
		c.Unsign = "UNSIGNED"
	}
	return c
}

// SetDataType sets the SQL data type of the column, e.g. "VARCHAR(255)".
func (c *Column) SetDataType(dataType string) *Column {
	c.DataType = dataType
	return c
}

// SetOldNullable records the previous nullability for reverse migrations.
func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn {
	if null {
		c.OldNull = ""

	} else {
		c.OldNull = "NOT NULL"
	}
	return c
}

// SetOldDefault records the previous default for reverse migrations.
// NOTE(review): unlike SetDefault, this does not prepend "DEFAULT " —
// confirm whether callers are expected to include it.
func (c *RenameColumn) SetOldDefault(def string) *RenameColumn {
	c.OldDefault = def
	return c
}

// SetOldUnsigned records the previous signedness for reverse migrations.
func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn {
	if unsign {
		c.OldUnsign = "UNSIGNED"
	}
	return c
}

// SetOldDataType records the previous data type for reverse migrations.
func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn {
	c.OldDataType = dataType
	return c
}
+
// SetPrimary adds the column to the migration's primary key (may be called
// any number of times, but only within one migration per table).
func (c *Column) SetPrimary(m *Migration) *Column {
	m.Primary = append(m.Primary, c)
	return c
}

// AddColumnsToUnique appends the given columns to the unique key.
func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique {

	unique.Columns = append(unique.Columns, columns...)

	return unique
}

// AddColumns appends the given columns to the migration.
func (m *Migration) AddColumns(columns ...*Column) *Migration {

	m.Columns = append(m.Columns, columns...)

	return m
}

// AddPrimary appends the column to the migration's primary key.
func (m *Migration) AddPrimary(primary *Column) *Migration {
	m.Primary = append(m.Primary, primary)
	return m
}

// AddUnique appends the unique key to the migration.
func (m *Migration) AddUnique(unique *Unique) *Migration {
	m.Uniques = append(m.Uniques, unique)
	return m
}

// AddForeign appends the foreign key to the migration.
func (m *Migration) AddForeign(foreign *Foreign) *Migration {
	m.Foreigns = append(m.Foreigns, foreign)
	return m
}

// AddIndex appends the index to the migration.
func (m *Migration) AddIndex(index *Index) *Migration {
	m.Indexes = append(m.Indexes, index)
	return m
}

// RenameColumn queues a rename of column from to to and returns the
// RenameColumn so the old/new definitions can be chained.
func (m *Migration) RenameColumn(from, to string) *RenameColumn {
	rename := &RenameColumn{OldName: from, NewName: to}
	m.Renames = append(m.Renames, rename)
	return rename
}
+
+//GetSQL returns the generated sql depending on ModifyType
+func (m *Migration) GetSQL() (sql string) {
+ sql = ""
+ switch m.ModifyType {
+ case "create":
+ {
+ sql += fmt.Sprintf("CREATE TABLE `%s` (", m.TableName)
+ for index, column := range m.Columns {
+ sql += fmt.Sprintf("\n `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
+ if len(m.Columns) > index+1 {
+ sql += ","
+ }
+ }
+
+ if len(m.Primary) > 0 {
+ sql += fmt.Sprintf(",\n PRIMARY KEY( ")
+ }
+ for index, column := range m.Primary {
+ sql += fmt.Sprintf(" `%s`", column.Name)
+ if len(m.Primary) > index+1 {
+ sql += ","
+ }
+
+ }
+ if len(m.Primary) > 0 {
+ sql += fmt.Sprintf(")")
+ }
+
+ for _, unique := range m.Uniques {
+ sql += fmt.Sprintf(",\n UNIQUE KEY `%s`( ", unique.Definition)
+ for index, column := range unique.Columns {
+ sql += fmt.Sprintf(" `%s`", column.Name)
+ if len(unique.Columns) > index+1 {
+ sql += ","
+ }
+ }
+ sql += fmt.Sprintf(")")
+ }
+ for _, foreign := range m.Foreigns {
+ sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
+ sql += fmt.Sprintf(",\n KEY `%s_%s_foreign`(`%s`),", m.TableName, foreign.Column.Name, foreign.Column.Name)
+ sql += fmt.Sprintf("\n CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate)
+
+ }
+ sql += fmt.Sprintf(")ENGINE=%s DEFAULT CHARSET=%s;", m.Engine, m.Charset)
+ break
+ }
+ case "alter":
+ {
+ sql += fmt.Sprintf("ALTER TABLE `%s` ", m.TableName)
+ for index, column := range m.Columns {
+ if !column.remove {
+ beego.BeeLogger.Info("col")
+ sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
+ } else {
+ sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
+ }
+
+ if len(m.Columns) > index {
+ sql += ","
+ }
+ }
+ for index, column := range m.Renames {
+ sql += fmt.Sprintf("CHANGE COLUMN `%s` `%s` %s %s %s %s %s", column.OldName, column.NewName, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
+ if len(m.Renames) > index+1 {
+ sql += ","
+ }
+ }
+
+ for index, foreign := range m.Foreigns {
+ sql += fmt.Sprintf("ADD `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default)
+ sql += fmt.Sprintf(",\n ADD KEY `%s_%s_foreign`(`%s`)", m.TableName, foreign.Column.Name, foreign.Column.Name)
+ sql += fmt.Sprintf(",\n ADD CONSTRAINT `%s_%s_foreign` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) %s %s", m.TableName, foreign.Column.Name, foreign.Column.Name, foreign.ForeignTable, foreign.ForeignColumn, foreign.OnDelete, foreign.OnUpdate)
+ if len(m.Foreigns) > index+1 {
+ sql += ","
+ }
+ }
+ sql += ";"
+
+ break
+ }
+ case "reverse":
+ {
+
+ sql += fmt.Sprintf("ALTER TABLE `%s`", m.TableName)
+ for index, column := range m.Columns {
+ if column.remove {
+ sql += fmt.Sprintf("\n ADD `%s` %s %s %s %s %s", column.Name, column.DataType, column.Unsign, column.Null, column.Inc, column.Default)
+ } else {
+ sql += fmt.Sprintf("\n DROP COLUMN `%s`", column.Name)
+ }
+ if len(m.Columns) > index {
+ sql += ","
+ }
+ }
+
+ if len(m.Primary) > 0 {
+ sql += fmt.Sprintf("\n DROP PRIMARY KEY,")
+ }
+
+ for index, unique := range m.Uniques {
+ sql += fmt.Sprintf("\n DROP KEY `%s`", unique.Definition)
+ if len(m.Uniques) > index {
+ sql += ","
+ }
+
+ }
+ for index, column := range m.Renames {
+ sql += fmt.Sprintf("\n CHANGE COLUMN `%s` `%s` %s %s %s %s", column.NewName, column.OldName, column.OldDataType, column.OldUnsign, column.OldNull, column.OldDefault)
+ if len(m.Renames) > index {
+ sql += ","
+ }
+ }
+
+ for _, foreign := range m.Foreigns {
+ sql += fmt.Sprintf("\n DROP KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name)
+ sql += fmt.Sprintf(",\n DROP FOREIGN KEY `%s_%s_foreign`", m.TableName, foreign.Column.Name)
+ sql += fmt.Sprintf(",\n DROP COLUMN `%s`", foreign.Name)
+ }
+ sql += ";"
+ }
+ case "delete":
+ {
+ sql += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;", m.TableName)
+ }
+ }
+
+ return
}
diff --git a/migration/doc.go b/migration/doc.go
new file mode 100644
index 00000000..0c6564d4
--- /dev/null
+++ b/migration/doc.go
@@ -0,0 +1,32 @@
+// Package migration enables you to generate migrations back and forth. It generates both migrations.
+//
+// //Creates a table
+// m.CreateTable("tablename","InnoDB","utf8");
+//
+// //Alter a table
+// m.AlterTable("tablename")
+//
+// Standard Column Methods
+// * SetDataType
+// * SetNullable
+// * SetDefault
+// * SetUnsigned (use only on integer types unless produces error)
+//
+// //Sets a primary column, multiple calls allowed, standard column methods available
+// m.PriCol("id").SetAuto(true).SetNullable(false).SetDataType("INT(10)").SetUnsigned(true)
+//
+// //UniCol Can be used multiple times, allows standard Column methods. Use same "index" string to add to same index
+// m.UniCol("index","column")
+//
+// //Standard Column Initialisation, can call .Remove() after NewCol("") on alter to remove
+// m.NewCol("name").SetDataType("VARCHAR(255) COLLATE utf8_unicode_ci").SetNullable(false)
+// m.NewCol("value").SetDataType("DOUBLE(8,2)").SetNullable(false)
+//
+// //Rename Columns, only use with Alter table; doesn't work with Create. Prefix standard column methods with "Old" to
+// //create a true reversible migration eg: SetOldDataType("DOUBLE(12,3)")
+// m.RenameColumn("from","to")...
+//
+// //Foreign Columns, single columns are only supported, SetOnDelete & SetOnUpdate are available, call appropriately.
+// //Supports standard column methods, automatic reverse.
+// m.ForeignCol("local_col","foreign_col","foreign_table")
+package migration
diff --git a/migration/migration.go b/migration/migration.go
index 9e03fc33..97e10c2e 100644
--- a/migration/migration.go
+++ b/migration/migration.go
@@ -52,6 +52,26 @@ type Migrationer interface {
GetCreated() int64
}
// Migration defines a migration expressed either as raw SQL statements or
// via the DDL builder methods (CreateTable, AlterTable, NewCol, ...).
type Migration struct {
	sqls           []string // collected SQL statements to be executed
	Created        string
	TableName      string
	Engine         string
	Charset        string
	ModifyType     string // one of "create", "alter", "reverse", "delete"
	Columns        []*Column
	Indexes        []*Index
	Primary        []*Column
	Uniques        []*Unique
	Foreigns       []*Foreign
	Renames        []*RenameColumn
	RemoveColumns  []*Column
	RemoveIndexes  []*Index
	RemoveUniques  []*Unique
	RemoveForeigns []*Foreign
}
+
var (
migrationMap map[string]Migrationer
)
@@ -60,20 +80,34 @@ func init() {
migrationMap = make(map[string]Migrationer)
}
-// Migration the basic type which will implement the basic type
-type Migration struct {
- sqls []string
- Created string
-}
-
// Up implement in the Inheritance struct for upgrade
func (m *Migration) Up() {
+ switch m.ModifyType {
+ case "reverse":
+ m.ModifyType = "alter"
+ case "delete":
+ m.ModifyType = "create"
+ }
+ m.sqls = append(m.sqls, m.GetSQL())
}
// Down implement in the Inheritance struct for down
func (m *Migration) Down() {
+ switch m.ModifyType {
+ case "alter":
+ m.ModifyType = "reverse"
+ case "create":
+ m.ModifyType = "delete"
+ }
+ m.sqls = append(m.sqls, m.GetSQL())
+}
+
+//Migrate adds the SQL to the execution list
+func (m *Migration) Migrate(migrationType string) {
+ m.ModifyType = migrationType
+ m.sqls = append(m.sqls, m.GetSQL())
}
// SQL add sql want to execute
diff --git a/namespace.go b/namespace.go
index cfde0111..72f22a72 100644
--- a/namespace.go
+++ b/namespace.go
@@ -267,13 +267,12 @@ func addPrefix(t *Tree, prefix string) {
addPrefix(t.wildcard, prefix)
}
for _, l := range t.leaves {
- if c, ok := l.runObject.(*controllerInfo); ok {
+ if c, ok := l.runObject.(*ControllerInfo); ok {
if !strings.HasPrefix(c.pattern, prefix) {
c.pattern = prefix + c.pattern
}
}
}
-
}
// NSCond is Namespace Condition
@@ -284,16 +283,16 @@ func NSCond(cond namespaceCond) LinkNamespace {
}
// NSBefore Namespace BeforeRouter filter
-func NSBefore(filiterList ...FilterFunc) LinkNamespace {
+func NSBefore(filterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) {
- ns.Filter("before", filiterList...)
+ ns.Filter("before", filterList...)
}
}
// NSAfter add Namespace FinishRouter filter
-func NSAfter(filiterList ...FilterFunc) LinkNamespace {
+func NSAfter(filterList ...FilterFunc) LinkNamespace {
return func(ns *Namespace) {
- ns.Filter("after", filiterList...)
+ ns.Filter("after", filterList...)
}
}
diff --git a/namespace_test.go b/namespace_test.go
index fc02b5fb..b3f20dff 100644
--- a/namespace_test.go
+++ b/namespace_test.go
@@ -139,10 +139,7 @@ func TestNamespaceCond(t *testing.T) {
ns := NewNamespace("/v2")
ns.Cond(func(ctx *context.Context) bool {
- if ctx.Input.Domain() == "beego.me" {
- return true
- }
- return false
+ return ctx.Input.Domain() == "beego.me"
}).
AutoRouter(&TestController{})
AddNamespace(ns)
diff --git a/orm/cmd.go b/orm/cmd.go
index 3638a75c..0ff4dc40 100644
--- a/orm/cmd.go
+++ b/orm/cmd.go
@@ -150,7 +150,7 @@ func (d *commandSyncDb) Run() error {
}
for _, fi := range mi.fields.fieldsDB {
- if _, ok := columns[fi.column]; ok == false {
+ if _, ok := columns[fi.column]; !ok {
fields = append(fields, fi)
}
}
@@ -175,7 +175,7 @@ func (d *commandSyncDb) Run() error {
}
for _, idx := range indexes[mi.table] {
- if d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) == false {
+ if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) {
if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}
diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go
index 8119b70b..de47cb02 100644
--- a/orm/cmd_utils.go
+++ b/orm/cmd_utils.go
@@ -89,7 +89,7 @@ checkColumn:
col = T["float64"]
case TypeDecimalField:
s := T["float64-decimal"]
- if strings.Index(s, "%d") == -1 {
+ if !strings.Contains(s, "%d") {
col = s
} else {
col = fmt.Sprintf(s, fi.digits, fi.decimals)
@@ -120,7 +120,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi)
- if fi.null == false {
+ if !fi.null {
typ += " " + "NOT NULL"
}
@@ -172,7 +172,7 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
} else {
column += col
- if fi.null == false {
+ if !fi.null {
column += " " + "NOT NULL"
}
@@ -192,7 +192,7 @@ func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
}
}
- if strings.Index(column, "%COL%") != -1 {
+ if strings.Contains(column, "%COL%") {
column = strings.Replace(column, "%COL%", fi.column, -1)
}
diff --git a/orm/db.go b/orm/db.go
index bca6071d..12f0f54d 100644
--- a/orm/db.go
+++ b/orm/db.go
@@ -48,7 +48,7 @@ var (
"lte": true,
"eq": true,
"nq": true,
- "ne": true,
+ "ne": true,
"startswith": true,
"endswith": true,
"istartswith": true,
@@ -87,7 +87,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
} else {
panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName))
}
- if fi.dbcol == false || fi.auto && skipAuto {
+ if !fi.dbcol || fi.auto && skipAuto {
continue
}
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
@@ -224,7 +224,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
value = nil
}
}
- if fi.null == false && value == nil {
+ if !fi.null && value == nil {
return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
}
}
@@ -271,7 +271,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
dbcols := make([]string, 0, len(mi.fields.dbcols))
marks := make([]string, 0, len(mi.fields.dbcols))
for _, fi := range mi.fields.fieldsDB {
- if fi.auto == false {
+ if !fi.auto {
dbcols = append(dbcols, fi.column)
marks = append(marks, "?")
}
@@ -326,7 +326,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
} else {
// default use pk value as where condtion.
pkColumn, pkValue, ok := getExistPk(mi, ind)
- if ok == false {
+ if !ok {
return ErrMissPK
}
whereCols = []string{pkColumn}
@@ -507,10 +507,9 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
case DRPostgres:
if len(args) == 0 {
return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName)
- } else {
- args0 = strings.ToLower(args[0])
- iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
}
+ args0 = strings.ToLower(args[0])
+ iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0)
default:
return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName)
}
@@ -592,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
row := q.QueryRow(query, values...)
var id int64
err = row.Scan(&id)
- if err.Error() == `pq: syntax error at or near "ON"` {
+ if err != nil && err.Error() == `pq: syntax error at or near "ON"` {
err = fmt.Errorf("postgres version must 9.5 or higher")
}
return id, err
@@ -601,7 +600,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a
// execute update sql dbQuerier with given struct reflect.Value.
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
pkName, pkValue, ok := getExistPk(mi, ind)
- if ok == false {
+ if !ok {
return 0, ErrMissPK
}
@@ -654,7 +653,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
} else {
// default use pk value as where condtion.
pkColumn, pkValue, ok := getExistPk(mi, ind)
- if ok == false {
+ if !ok {
return 0, ErrMissPK
}
whereCols = []string{pkColumn}
@@ -699,7 +698,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params))
for col, val := range params {
- if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false {
+ if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol {
panic(fmt.Errorf("wrong field/column name `%s`", col))
} else {
columns = append(columns, fi.column)
@@ -834,7 +833,11 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
if err := rs.Scan(&ref); err != nil {
return 0, err
}
- args = append(args, reflect.ValueOf(ref).Interface())
+ pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz)
+ if err != nil {
+ return 0, err
+ }
+ args = append(args, pkValue)
cnt++
}
@@ -929,7 +932,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
if hasRel {
for _, fi := range mi.fields.fieldsDB {
if fi.fieldType&IsRelField > 0 {
- if maps[fi.column] == false {
+ if !maps[fi.column] {
tCols = append(tCols, fi.column)
}
}
@@ -987,7 +990,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
var cnt int64
for rs.Next() {
- if one && cnt == 0 || one == false {
+ if one && cnt == 0 || !one {
if err := rs.Scan(refs...); err != nil {
return 0, err
}
@@ -1067,7 +1070,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
cnt++
}
- if one == false {
+ if !one {
if cnt > 0 {
ind.Set(slice)
} else {
@@ -1110,7 +1113,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
// generate sql with replacing operator string placeholders and replaced values.
func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) {
- sql := ""
+ var sql string
params := getFlatParams(fi, args, tz)
if len(params) == 0 {
@@ -1357,7 +1360,7 @@ end:
func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) {
fieldType := fi.fieldType
- isNative := fi.isFielder == false
+ isNative := !fi.isFielder
setValue:
switch {
@@ -1533,7 +1536,7 @@ setValue:
}
}
- if isNative == false {
+ if !isNative {
fd := field.Addr().Interface().(Fielder)
err := fd.SetRaw(value)
if err != nil {
@@ -1594,7 +1597,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
infos = make([]*fieldInfo, 0, len(exprs))
for _, ex := range exprs {
index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
- if suc == false {
+ if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", ex))
}
cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q))
@@ -1733,7 +1736,7 @@ func (d *dbBase) TableQuote() string {
return "`"
}
-// replace value placeholer in parametered sql string.
+// replace value placeholder in parameterized sql string.
func (d *dbBase) ReplaceMarks(query *string) {
// default use `?` as mark, do nothing
}
diff --git a/orm/db_alias.go b/orm/db_alias.go
index c95d49c9..c7089239 100644
--- a/orm/db_alias.go
+++ b/orm/db_alias.go
@@ -60,6 +60,8 @@ var (
"sqlite3": DRSqlite,
"tidb": DRTiDB,
"oracle": DROracle,
+ "oci8": DROracle, // github.com/mattn/go-oci8
+ "ora": DROracle, //https://github.com/rana/ora
}
dbBasers = map[DriverType]dbBaser{
DRMySQL: newdbBaseMysql(),
@@ -186,7 +188,7 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) {
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
}
- if dataBaseCache.add(aliasName, al) == false {
+ if !dataBaseCache.add(aliasName, al) {
return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
}
@@ -244,11 +246,11 @@ end:
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error {
- if t, ok := drivers[driverName]; ok == false {
+ if t, ok := drivers[driverName]; !ok {
drivers[driverName] = typ
} else {
if t != typ {
- return fmt.Errorf("driverName `%s` db driver already registered and is other type\n", driverName)
+ return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
}
}
return nil
@@ -259,7 +261,7 @@ func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz
} else {
- return fmt.Errorf("DataBase alias name `%s` not registered\n", aliasName)
+ return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
}
return nil
}
@@ -294,5 +296,5 @@ func GetDB(aliasNames ...string) (*sql.DB, error) {
if ok {
return al.DB, nil
}
- return nil, fmt.Errorf("DataBase of alias name `%s` not found\n", name)
+ return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
}
diff --git a/orm/db_mysql.go b/orm/db_mysql.go
index 1016de2b..51185563 100644
--- a/orm/db_mysql.go
+++ b/orm/db_mysql.go
@@ -103,8 +103,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
// If no will insert
// Add "`" for mysql sql building
func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
-
- iouStr := ""
+ var iouStr string
argsMap := map[string]string{}
iouStr = "ON DUPLICATE KEY UPDATE"
diff --git a/orm/db_oracle.go b/orm/db_oracle.go
index deca36ad..f5d6aaa2 100644
--- a/orm/db_oracle.go
+++ b/orm/db_oracle.go
@@ -94,3 +94,43 @@ func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool
row.Scan(&cnt)
return cnt > 0
}
+
+// execute insert sql with given struct and given values.
+// insert the given values, not the field values in struct.
+func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
+ Q := d.ins.TableQuote()
+
+ marks := make([]string, len(names))
+ for i := range marks {
+ marks[i] = ":" + names[i]
+ }
+
+ sep := fmt.Sprintf("%s, %s", Q, Q)
+ qmarks := strings.Join(marks, ", ")
+ columns := strings.Join(names, sep)
+
+ multi := len(values) / len(names)
+
+ if isMulti {
+ qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
+ }
+
+ query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
+
+ d.ins.ReplaceMarks(&query)
+
+ if isMulti || !d.ins.HasReturningID(mi, &query) {
+ res, err := q.Exec(query, values...)
+ if err == nil {
+ if isMulti {
+ return res.RowsAffected()
+ }
+ return res.LastInsertId()
+ }
+ return 0, err
+ }
+ row := q.QueryRow(query, values...)
+ var id int64
+ err := row.Scan(&id)
+ return id, err
+}
diff --git a/orm/db_sqlite.go b/orm/db_sqlite.go
index a3cb69a7..a43a5594 100644
--- a/orm/db_sqlite.go
+++ b/orm/db_sqlite.go
@@ -134,7 +134,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
defer rows.Close()
for rows.Next() {
var tmp, index sql.NullString
- rows.Scan(&tmp, &index, &tmp)
+ rows.Scan(&tmp, &index, &tmp, &tmp, &tmp)
if name == index.String {
return true
}
diff --git a/orm/db_tables.go b/orm/db_tables.go
index e4c74ace..42be5550 100644
--- a/orm/db_tables.go
+++ b/orm/db_tables.go
@@ -63,7 +63,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
// add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep)
- if _, ok := t.tablesM[name]; ok == false {
+ if _, ok := t.tablesM[name]; !ok {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
@@ -261,7 +261,7 @@ loopFor:
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
}
- if isRel && (fi.mi.isThrough == false || num != i) {
+ if isRel && (!fi.mi.isThrough || num != i) {
if fi.null || t.skipEnd {
inner = false
}
@@ -364,7 +364,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe
}
index, _, fi, suc := t.parseExprs(mi, exprs)
- if suc == false {
+ if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
}
@@ -383,7 +383,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe
}
}
- if sub == false && where != "" {
+ if !sub && where != "" {
where = "WHERE " + where
}
@@ -403,7 +403,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
exprs := strings.Split(group, ExprSep)
index, _, fi, suc := t.parseExprs(t.mi, exprs)
- if suc == false {
+ if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
@@ -432,7 +432,7 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
exprs := strings.Split(order, ExprSep)
index, _, fi, suc := t.parseExprs(t.mi, exprs)
- if suc == false {
+ if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
diff --git a/orm/db_utils.go b/orm/db_utils.go
index 923917ec..7ae10ca5 100644
--- a/orm/db_utils.go
+++ b/orm/db_utils.go
@@ -41,6 +41,8 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
vu := v.Int()
exist = true
value = vu
+ } else if fi.fieldType&IsRelField > 0 {
+ _, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v))
} else {
vu := v.String()
exist = vu != ""
diff --git a/orm/models_boot.go b/orm/models_boot.go
index 7082fa3e..badfd11b 100644
--- a/orm/models_boot.go
+++ b/orm/models_boot.go
@@ -75,7 +75,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
}
if mi.fields.pk == nil {
- fmt.Printf(" `%s` need a primary key field, default use 'id' if not set\n", name)
+ fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name)
os.Exit(2)
}
@@ -117,7 +117,7 @@ func bootStrap() {
name := getFullName(elm)
mii, ok := modelCache.getByFullName(name)
if !ok || mii.pkg != elm.PkgPath() {
- err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
+ err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
goto end
}
fi.relModelInfo = mii
@@ -128,7 +128,7 @@ func bootStrap() {
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
pn := fi.relThrough[:i]
rmi, ok := modelCache.getByFullName(fi.relThrough)
- if ok == false || pn != rmi.pkg {
+ if !ok || pn != rmi.pkg {
err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough)
goto end
}
@@ -171,7 +171,7 @@ func bootStrap() {
break
}
}
- if inModel == false {
+ if !inModel {
rmi := fi.relModelInfo
ffi := new(fieldInfo)
ffi.name = mi.name
@@ -185,7 +185,7 @@ func bootStrap() {
} else {
ffi.fieldType = RelReverseMany
}
- if rmi.fields.Add(ffi) == false {
+ if !rmi.fields.Add(ffi) {
added := false
for cnt := 0; cnt < 5; cnt++ {
ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
@@ -195,7 +195,7 @@ func bootStrap() {
break
}
}
- if added == false {
+ if !added {
panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
}
}
@@ -248,7 +248,7 @@ func bootStrap() {
break mForA
}
}
- if found == false {
+ if !found {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end
}
@@ -267,7 +267,7 @@ func bootStrap() {
break mForB
}
}
- if found == false {
+ if !found {
mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough ||
@@ -287,7 +287,7 @@ func bootStrap() {
}
}
}
- if found == false {
+ if !found {
err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end
}
diff --git a/orm/models_info_f.go b/orm/models_info_f.go
index 4b3d3e27..bbb7d71f 100644
--- a/orm/models_info_f.go
+++ b/orm/models_info_f.go
@@ -47,7 +47,7 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
} else {
return
}
- if _, ok := f.fieldsByType[fi.fieldType]; ok == false {
+ if _, ok := f.fieldsByType[fi.fieldType]; !ok {
f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0)
}
f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi)
@@ -334,12 +334,12 @@ checkType:
switch onDelete {
case odCascade, odDoNothing:
case odSetDefault:
- if initial.Exist() == false {
+ if !initial.Exist() {
err = errors.New("on_delete: set_default need set field a default value")
goto end
}
case odSetNULL:
- if fi.null == false {
+ if !fi.null {
err = errors.New("on_delete: set_null need set field null")
goto end
}
diff --git a/orm/models_info_m.go b/orm/models_info_m.go
index d6ba1dca..4a3a37f9 100644
--- a/orm/models_info_m.go
+++ b/orm/models_info_m.go
@@ -78,7 +78,7 @@ func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int)
fi.fieldIndex = append(index, i)
fi.mi = mi
fi.inModel = true
- if mi.fields.Add(fi) == false {
+ if !mi.fields.Add(fi) {
err = fmt.Errorf("duplicate column name: %s", fi.column)
break
}
diff --git a/orm/models_test.go b/orm/models_test.go
index 462370b2..9843a87d 100644
--- a/orm/models_test.go
+++ b/orm/models_test.go
@@ -406,6 +406,11 @@ type UintPk struct {
Name string
}
+type PtrPk struct {
+ ID *IntegerPk `orm:"pk;rel(one)"`
+ Positive bool
+}
+
var DBARGS = struct {
Driver string
Source string
diff --git a/orm/orm.go b/orm/orm.go
index 538916e4..fcf82590 100644
--- a/orm/orm.go
+++ b/orm/orm.go
@@ -107,7 +107,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect
if mi, ok := modelCache.getByFullName(name); ok {
return mi, ind
}
- panic(fmt.Errorf(" table: `%s` not found, maybe not RegisterModel", name))
+ panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name))
}
// get field info from model info by given field name
@@ -122,21 +122,13 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
// read data to model
func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true)
- err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
- if err != nil {
- return err
- }
- return nil
+ return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
}
// read data to model, like Read(), but use "SELECT FOR UPDATE" form
func (o *orm) ReadForUpdate(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md, true)
- err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
- if err != nil {
- return err
- }
- return nil
+ return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
}
// Try to read a row from the database, or insert one if it doesn't exist
@@ -153,6 +145,8 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i
id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
id = int64(vid.Uint())
+ } else if mi.fields.pk.rel {
+ return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name)
} else {
id = vid.Int()
}
@@ -236,15 +230,11 @@ func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64
// cols set the columns those want to update.
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
- num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
- if err != nil {
- return num, err
- }
- return num, nil
+ return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
}
// delete model in database
-// cols shows the delete conditions values read from. deafult is pk
+// cols shows the delete conditions values read from. default is pk
func (o *orm) Delete(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
@@ -359,7 +349,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
fi := o.getFieldInfo(mi, name)
_, _, exist := getExistPk(mi, ind)
- if exist == false {
+ if !exist {
panic(ErrMissPK)
}
@@ -430,7 +420,7 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
- name := ""
+ var name string
if table, ok := ptrStructOrTableName.(string); ok {
name = snakeString(table)
if mi, ok := modelCache.get(name); ok {
@@ -487,7 +477,7 @@ func (o *orm) Begin() error {
// commit transaction
func (o *orm) Commit() error {
- if o.isTx == false {
+ if !o.isTx {
return ErrTxDone
}
err := o.db.(txEnder).Commit()
@@ -502,7 +492,7 @@ func (o *orm) Commit() error {
// rollback transaction
func (o *orm) Rollback() error {
- if o.isTx == false {
+ if !o.isTx {
return ErrTxDone
}
err := o.db.(txEnder).Rollback()
diff --git a/orm/orm_querym2m.go b/orm/orm_querym2m.go
index b220bda6..6a270a0d 100644
--- a/orm/orm_querym2m.go
+++ b/orm/orm_querym2m.go
@@ -72,7 +72,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
}
_, v1, exist := getExistPk(o.mi, o.ind)
- if exist == false {
+ if !exist {
panic(ErrMissPK)
}
@@ -87,7 +87,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
v2 = ind.Interface()
} else {
_, v2, exist = getExistPk(fi.relModelInfo, ind)
- if exist == false {
+ if !exist {
panic(ErrMissPK)
}
}
@@ -104,11 +104,7 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
- nums, err := qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete()
- if err != nil {
- return nums, err
- }
- return nums, nil
+ return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete()
}
// check model is existed in relationship of origin model
diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go
index 575f62ae..4e33646d 100644
--- a/orm/orm_queryset.go
+++ b/orm/orm_queryset.go
@@ -153,6 +153,11 @@ func (o querySet) SetCond(cond *Condition) QuerySeter {
return &o
}
+// GetCond returns the condition of this QuerySeter
+func (o querySet) GetCond() *Condition {
+ return o.cond
+}
+
// return QuerySeter execution result number
func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
diff --git a/orm/orm_raw.go b/orm/orm_raw.go
index a968b1a1..c8e741ea 100644
--- a/orm/orm_raw.go
+++ b/orm/orm_raw.go
@@ -493,19 +493,33 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
}
}
} else {
- for i := 0; i < ind.NumField(); i++ {
- f := ind.Field(i)
- fe := ind.Type().Field(i)
- _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
- var col string
- if col = tags["column"]; col == "" {
- col = snakeString(fe.Name)
- }
- if v, ok := columnsMp[col]; ok {
- value := reflect.ValueOf(v).Elem().Interface()
- o.setFieldValue(f, value)
+ // define recursive function
+ var recursiveSetField func(rv reflect.Value)
+ recursiveSetField = func(rv reflect.Value) {
+ for i := 0; i < rv.NumField(); i++ {
+ f := rv.Field(i)
+ fe := rv.Type().Field(i)
+
+ // check if the field is a Struct
+ // recurse into the struct type
+ if fe.Type.Kind() == reflect.Struct {
+ recursiveSetField(f)
+ }
+
+ _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
+ var col string
+ if col = tags["column"]; col == "" {
+ col = snakeString(fe.Name)
+ }
+ if v, ok := columnsMp[col]; ok {
+ value := reflect.ValueOf(v).Elem().Interface()
+ o.setFieldValue(f, value)
+ }
}
}
+
+ // initial call to the recursive function
+ recursiveSetField(ind)
}
if eTyps[0].Kind() == reflect.Ptr {
@@ -671,7 +685,7 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in
ind *reflect.Value
)
- typ := 0
+ var typ int
switch container.(type) {
case *Params:
typ = 1
diff --git a/orm/orm_test.go b/orm/orm_test.go
index adfe0066..f1f2d85e 100644
--- a/orm/orm_test.go
+++ b/orm/orm_test.go
@@ -93,14 +93,14 @@ wrongArg:
}
func AssertIs(a interface{}, args ...interface{}) error {
- if ok, err := ValuesCompare(true, a, args...); ok == false {
+ if ok, err := ValuesCompare(true, a, args...); !ok {
return err
}
return nil
}
func AssertNot(a interface{}, args ...interface{}) error {
- if ok, err := ValuesCompare(false, a, args...); ok == false {
+ if ok, err := ValuesCompare(false, a, args...); !ok {
return err
}
return nil
@@ -135,7 +135,7 @@ func getCaller(skip int) string {
if i := strings.LastIndex(funName, "."); i > -1 {
funName = funName[i+1:]
}
- return fmt.Sprintf("%s:%d: \n%s", fn, line, strings.Join(codes, "\n"))
+ return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n"))
}
func throwFail(t *testing.T, err error, args ...interface{}) {
@@ -193,6 +193,7 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(InLineOneToOne))
RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk))
+ RegisterModel(new(PtrPk))
err := RunSyncdb("default", true, Debug)
throwFail(t, err)
@@ -216,6 +217,7 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(InLineOneToOne))
RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk))
+ RegisterModel(new(PtrPk))
BootStrap()
@@ -1012,6 +1014,8 @@ func TestAll(t *testing.T) {
var users3 []*User
qs = dORM.QueryTable("user")
num, err = qs.Filter("user_name", "nothing").All(&users3)
+ throwFailNow(t, err)
+ throwFailNow(t, AssertIs(num, 0))
throwFailNow(t, AssertIs(users3 == nil, false))
}
@@ -1136,6 +1140,7 @@ func TestRelatedSel(t *testing.T) {
}
err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user)
+ throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(user.Profile, nil))
@@ -1244,20 +1249,24 @@ func TestLoadRelated(t *testing.T) {
num, err = dORM.LoadRelated(&user, "Posts", true)
throwFailNow(t, err)
+ throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie"))
num, err = dORM.LoadRelated(&user, "Posts", true, 1)
throwFailNow(t, err)
+ throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(user.Posts), 1))
num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id")
throwFailNow(t, err)
+ throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(user.Posts), 2))
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id")
throwFailNow(t, err)
+ throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(user.Posts), 1))
throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
@@ -1652,6 +1661,13 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(pid, nil))
}
+// user_profile table
+type userProfile struct {
+ User
+ Age int
+ Money float64
+}
+
func TestQueryRows(t *testing.T) {
Q := dDbBaser.TableQuote()
@@ -1722,6 +1738,19 @@ func TestQueryRows(t *testing.T) {
throwFailNow(t, AssertIs(usernames[1], "astaxie"))
throwFailNow(t, AssertIs(ids[2], 4))
throwFailNow(t, AssertIs(usernames[2], "nobody"))
+
+ //test query rows by nested struct
+ var l []userProfile
+ query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q)
+ num, err = dORM.Raw(query).QueryRows(&l)
+ throwFailNow(t, err)
+ throwFailNow(t, AssertIs(num, 2))
+ throwFailNow(t, AssertIs(len(l), 2))
+ throwFailNow(t, AssertIs(l[0].UserName, "slene"))
+ throwFailNow(t, AssertIs(l[0].Age, 28))
+ throwFailNow(t, AssertIs(l[1].UserName, "astaxie"))
+ throwFailNow(t, AssertIs(l[1].Age, 30))
+
}
func TestRawValues(t *testing.T) {
@@ -1974,6 +2003,7 @@ func TestReadOrCreate(t *testing.T) {
created, pk, err := dORM.ReadOrCreate(u, "UserName")
throwFail(t, err)
throwFail(t, AssertIs(created, true))
+ throwFail(t, AssertIs(u.ID, pk))
throwFail(t, AssertIs(u.UserName, "Kyle"))
throwFail(t, AssertIs(u.Email, "kylemcc@gmail.com"))
throwFail(t, AssertIs(u.Password, "other_pass"))
@@ -2128,13 +2158,13 @@ func TestUintPk(t *testing.T) {
Name: name,
}
- created, pk, err := dORM.ReadOrCreate(u, "ID")
+ created, _, err := dORM.ReadOrCreate(u, "ID")
throwFail(t, err)
throwFail(t, AssertIs(created, true))
throwFail(t, AssertIs(u.Name, name))
nu := &UintPk{ID: 8}
- created, pk, err = dORM.ReadOrCreate(nu, "ID")
+ created, pk, err := dORM.ReadOrCreate(nu, "ID")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(nu.ID, u.ID))
@@ -2144,6 +2174,48 @@ func TestUintPk(t *testing.T) {
dORM.Delete(u)
}
+func TestPtrPk(t *testing.T) {
+ parent := &IntegerPk{ID: 10, Value: "10"}
+
+ id, _ := dORM.Insert(parent)
+ if !IsMysql {
+ // MySQL does not support last_insert_id in this case: see #2382
+ throwFail(t, AssertIs(id, 10))
+ }
+
+ ptr := PtrPk{ID: parent, Positive: true}
+ num, err := dORM.InsertMulti(2, []PtrPk{ptr})
+ throwFail(t, err)
+ throwFail(t, AssertIs(num, 1))
+ throwFail(t, AssertIs(ptr.ID, parent))
+
+ nptr := &PtrPk{ID: parent}
+ created, pk, err := dORM.ReadOrCreate(nptr, "ID")
+ throwFail(t, err)
+ throwFail(t, AssertIs(created, false))
+ throwFail(t, AssertIs(pk, 10))
+ throwFail(t, AssertIs(nptr.ID, parent))
+ throwFail(t, AssertIs(nptr.Positive, true))
+
+ nptr = &PtrPk{Positive: true}
+ created, pk, err = dORM.ReadOrCreate(nptr, "Positive")
+ throwFail(t, err)
+ throwFail(t, AssertIs(created, false))
+ throwFail(t, AssertIs(pk, 10))
+ throwFail(t, AssertIs(nptr.ID, parent))
+
+ nptr.Positive = false
+ num, err = dORM.Update(nptr)
+ throwFail(t, err)
+ throwFail(t, AssertIs(num, 1))
+ throwFail(t, AssertIs(nptr.ID, parent))
+ throwFail(t, AssertIs(nptr.Positive, false))
+
+ num, err = dORM.Delete(nptr)
+ throwFail(t, err)
+ throwFail(t, AssertIs(num, 1))
+}
+
func TestSnake(t *testing.T) {
cases := map[string]string{
"i": "i",
diff --git a/orm/types.go b/orm/types.go
index fd3062ab..3e6a9e87 100644
--- a/orm/types.go
+++ b/orm/types.go
@@ -145,6 +145,16 @@ type QuerySeter interface {
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond1).Count()
SetCond(*Condition) QuerySeter
+ // GetCond returns the condition from the QuerySeter.
+ // sql's where condition
+ // cond := orm.NewCondition()
+ // cond = cond.And("profile__isnull", false).AndNot("status__in", 1)
+ // qs = qs.SetCond(cond)
+ // cond = qs.GetCond()
+ // cond = cond.Or("profile__age__gt", 2000)
+ // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
+ // num, err := qs.SetCond(cond).Count()
+ GetCond() *Condition
// add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
// if Limit <= 0 then Limit will be set to default limit ,eg 1000
diff --git a/orm/utils.go b/orm/utils.go
index bf43ceb0..669d4734 100644
--- a/orm/utils.go
+++ b/orm/utils.go
@@ -92,11 +92,11 @@ func (f StrTo) Int64() (int64, error) {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10) // octal
if !ok {
- return int64(v), err
+ return v, err
}
return ni.Int64(), nil
}
- return int64(v), err
+ return v, err
}
// Uint string to uint
@@ -130,11 +130,11 @@ func (f StrTo) Uint64() (uint64, error) {
i := new(big.Int)
ni, ok := i.SetString(f.String(), 10)
if !ok {
- return uint64(v), err
+ return v, err
}
return ni.Uint64(), nil
}
- return uint64(v), err
+ return v, err
}
// String string to string
@@ -219,22 +219,17 @@ func snakeString(s string) string {
// camel string, xx_yy to XxYy
func camelString(s string) string {
data := make([]byte, 0, len(s))
- j := false
- k := false
- num := len(s) - 1
+ flag, num := true, len(s)-1
for i := 0; i <= num; i++ {
d := s[i]
- if k == false && d >= 'A' && d <= 'Z' {
- k = true
- }
- if d >= 'a' && d <= 'z' && (j || k == false) {
- d = d - 32
- j = false
- k = true
- }
- if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' {
- j = true
+ if d == '_' {
+ flag = true
continue
+ } else if flag {
+ if d >= 'a' && d <= 'z' {
+ d = d - 32
+ }
+ flag = false
}
data = append(data, d)
}
diff --git a/orm/utils_test.go b/orm/utils_test.go
new file mode 100644
index 00000000..8c7c5008
--- /dev/null
+++ b/orm/utils_test.go
@@ -0,0 +1,36 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package orm
+
+import (
+ "testing"
+)
+
+func TestCamelString(t *testing.T) {
+ snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"}
+ camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"}
+
+ answer := make(map[string]string)
+ for i, v := range snake {
+ answer[v] = camel[i]
+ }
+
+ for _, v := range snake {
+ res := camelString(v)
+ if res != answer[v] {
+ t.Error("Unit Test Fail:", v, res, answer[v])
+ }
+ }
+}
diff --git a/parser.go b/parser.go
index d40ee3ce..1933c6c6 100644
--- a/parser.go
+++ b/parser.go
@@ -24,9 +24,13 @@ import (
"io/ioutil"
"os"
"path/filepath"
+ "regexp"
"sort"
+ "strconv"
"strings"
+ "unicode"
+ "github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils"
)
@@ -35,6 +39,7 @@ var globalRouterTemplate = `package routers
import (
"github.com/astaxie/beego"
+ "github.com/astaxie/beego/context/param"
)
func init() {
@@ -81,7 +86,7 @@ func parserPkg(pkgRealpath, pkgpath string) error {
if specDecl.Recv != nil {
exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser
if ok {
- parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath)
+ parserComments(specDecl, fmt.Sprint(exp.X), pkgpath)
}
}
}
@@ -93,44 +98,170 @@ func parserPkg(pkgRealpath, pkgpath string) error {
return nil
}
-func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpath string) error {
- if comments != nil && comments.List != nil {
- for _, c := range comments.List {
- t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
- if strings.HasPrefix(t, "@router") {
- elements := strings.TrimLeft(t, "@router ")
- e1 := strings.SplitN(elements, " ", 2)
- if len(e1) < 1 {
- return errors.New("you should has router information")
- }
- key := pkgpath + ":" + controllerName
- cc := ControllerComments{}
- cc.Method = funcName
- cc.Router = e1[0]
- if len(e1) == 2 && e1[1] != "" {
- e1 = strings.SplitN(e1[1], " ", 2)
- if len(e1) >= 1 {
- cc.AllowHTTPMethods = strings.Split(strings.Trim(e1[0], "[]"), ",")
- } else {
- cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get")
- }
- } else {
- cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get")
- }
- if len(e1) == 2 && e1[1] != "" {
- keyval := strings.Split(strings.Trim(e1[1], "[]"), " ")
- for _, kv := range keyval {
- kk := strings.Split(kv, ":")
- cc.Params = append(cc.Params, map[string]string{strings.Join(kk[:len(kk)-1], ":"): kk[len(kk)-1]})
- }
- }
- genInfoList[key] = append(genInfoList[key], cc)
- }
+type parsedComment struct {
+ routerPath string
+ methods []string
+ params map[string]parsedParam
+}
+
+type parsedParam struct {
+ name string
+ datatype string
+ location string
+ defValue string
+ required bool
+}
+
+func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
+ if f.Doc != nil {
+ parsedComment, err := parseComment(f.Doc.List)
+ if err != nil {
+ return err
}
+ if parsedComment.routerPath != "" {
+ key := pkgpath + ":" + controllerName
+ cc := ControllerComments{}
+ cc.Method = f.Name.String()
+ cc.Router = parsedComment.routerPath
+ cc.AllowHTTPMethods = parsedComment.methods
+ cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
+ genInfoList[key] = append(genInfoList[key], cc)
+ }
+
}
return nil
}
+func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
+ result := make([]*param.MethodParam, 0, len(funcParams))
+ for _, fparam := range funcParams {
+ for _, pName := range fparam.Names {
+ methodParam := buildMethodParam(fparam, pName.Name, pc)
+ result = append(result, methodParam)
+ }
+ }
+ return result
+}
+
+func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam {
+ options := []param.MethodParamOption{}
+ if cparam, ok := pc.params[name]; ok {
+ //Build param from comment info
+ name = cparam.name
+ if cparam.required {
+ options = append(options, param.IsRequired)
+ }
+ switch cparam.location {
+ case "body":
+ options = append(options, param.InBody)
+ case "header":
+ options = append(options, param.InHeader)
+ case "path":
+ options = append(options, param.InPath)
+ }
+ if cparam.defValue != "" {
+ options = append(options, param.Default(cparam.defValue))
+ }
+ } else {
+ if paramInPath(name, pc.routerPath) {
+ options = append(options, param.InPath)
+ }
+ }
+ return param.New(name, options...)
+}
+
+func paramInPath(name, route string) bool {
+ return strings.HasSuffix(route, ":"+name) ||
+ strings.Contains(route, ":"+name+"/")
+}
+
+var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
+
+func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) {
+ pc = &parsedComment{}
+ for _, c := range lines {
+ t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
+ if strings.HasPrefix(t, "@router") {
+ matches := routeRegex.FindStringSubmatch(t)
+ if len(matches) == 3 {
+ pc.routerPath = matches[1]
+ methods := matches[2]
+ if methods == "" {
+ pc.methods = []string{"get"}
+ //pc.hasGet = true
+ } else {
+ pc.methods = strings.Split(methods, ",")
+ //pc.hasGet = strings.Contains(methods, "get")
+ }
+ } else {
+ return nil, errors.New("Router information is missing")
+ }
+ } else if strings.HasPrefix(t, "@Param") {
+ pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param")))
+ if len(pv) < 4 {
+ logs.Error("Invalid @Param format. Needs at least 4 parameters")
+ }
+ p := parsedParam{}
+ names := strings.SplitN(pv[0], "=>", 2)
+ p.name = names[0]
+ funcParamName := p.name
+ if len(names) > 1 {
+ funcParamName = names[1]
+ }
+ p.location = pv[1]
+ p.datatype = pv[2]
+ switch len(pv) {
+ case 5:
+ p.required, _ = strconv.ParseBool(pv[3])
+ case 6:
+ p.defValue = pv[3]
+ p.required, _ = strconv.ParseBool(pv[4])
+ }
+ if pc.params == nil {
+ pc.params = map[string]parsedParam{}
+ }
+ pc.params[funcParamName] = p
+ }
+ }
+ return
+}
+
+// direct copy from bee\g_docs.go
+// analysis params return []string
+// @Param query form string true "The email for login"
+// [query form string true "The email for login"]
+func getparams(str string) []string {
+ var s []rune
+ var j int
+ var start bool
+ var r []string
+ var quoted int8
+ for _, c := range str {
+ if unicode.IsSpace(c) && quoted == 0 {
+ if !start {
+ continue
+ } else {
+ start = false
+ j++
+ r = append(r, string(s))
+ s = make([]rune, 0)
+ continue
+ }
+ }
+
+ start = true
+ if c == '"' {
+ quoted ^= 1
+ continue
+ }
+ s = append(s, c)
+ }
+ if len(s) > 0 {
+ r = append(r, string(s))
+ }
+ return r
+}
+
func genRouterCode(pkgRealpath string) {
os.Mkdir(getRouterDir(pkgRealpath), 0755)
logs.Info("generate router from comments")
@@ -144,6 +275,7 @@ func genRouterCode(pkgRealpath string) {
sort.Strings(sortKey)
for _, k := range sortKey {
cList := genInfoList[k]
+ sort.Sort(ControllerCommentsSlice(cList))
for _, c := range cList {
allmethod := "nil"
if len(c.AllowHTTPMethods) > 0 {
@@ -163,12 +295,24 @@ func genRouterCode(pkgRealpath string) {
}
params = strings.TrimRight(params, ",") + "}"
}
+ methodParams := "param.Make("
+ if len(c.MethodParams) > 0 {
+ lines := make([]string, 0, len(c.MethodParams))
+ for _, m := range c.MethodParams {
+ lines = append(lines, fmt.Sprint(m))
+ }
+ methodParams += "\n " +
+ strings.Join(lines, ",\n ") +
+ ",\n "
+ }
+ methodParams += ")"
globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{
Method: "` + strings.TrimSpace(c.Method) + `",
` + "Router: `" + c.Router + "`" + `,
AllowHTTPMethods: ` + allmethod + `,
+ MethodParams: ` + methodParams + `,
Params: ` + params + `})
`
}
diff --git a/plugins/apiauth/apiauth.go b/plugins/apiauth/apiauth.go
index 10636d1c..f816029c 100644
--- a/plugins/apiauth/apiauth.go
+++ b/plugins/apiauth/apiauth.go
@@ -56,6 +56,7 @@
package apiauth
import (
+ "bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
@@ -128,53 +129,32 @@ func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc {
// Signature used to generate signature with the appsecret/method/params/RequestURI
func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) {
- var query string
+ var b bytes.Buffer
+	keys := make([]string, 0, len(params))
pa := make(map[string]string)
for k, v := range params {
pa[k] = v[0]
+ keys = append(keys, k)
}
- vs := mapSorter(pa)
- vs.Sort()
- for i := 0; i < vs.Len(); i++ {
- if vs.Keys[i] == "signature" {
+
+ sort.Strings(keys)
+
+ for _, key := range keys {
+ if key == "signature" {
continue
}
- if vs.Keys[i] != "" && vs.Vals[i] != "" {
- query = fmt.Sprintf("%v%v%v", query, vs.Keys[i], vs.Vals[i])
+
+ val := pa[key]
+ if key != "" && val != "" {
+ b.WriteString(key)
+ b.WriteString(val)
}
}
- stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURL)
+
+ stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL)
sha256 := sha256.New
hash := hmac.New(sha256, []byte(appsecret))
hash.Write([]byte(stringToSign))
return base64.StdEncoding.EncodeToString(hash.Sum(nil))
}
-
-type valSorter struct {
- Keys []string
- Vals []string
-}
-
-func mapSorter(m map[string]string) *valSorter {
- vs := &valSorter{
- Keys: make([]string, 0, len(m)),
- Vals: make([]string, 0, len(m)),
- }
- for k, v := range m {
- vs.Keys = append(vs.Keys, k)
- vs.Vals = append(vs.Vals, v)
- }
- return vs
-}
-
-func (vs *valSorter) Sort() {
- sort.Sort(vs)
-}
-
-func (vs *valSorter) Len() int { return len(vs.Keys) }
-func (vs *valSorter) Less(i, j int) bool { return vs.Keys[i] < vs.Keys[j] }
-func (vs *valSorter) Swap(i, j int) {
- vs.Vals[i], vs.Vals[j] = vs.Vals[j], vs.Vals[i]
- vs.Keys[i], vs.Keys[j] = vs.Keys[j], vs.Keys[i]
-}
diff --git a/plugins/apiauth/apiauth_test.go b/plugins/apiauth/apiauth_test.go
new file mode 100644
index 00000000..1f56cb0f
--- /dev/null
+++ b/plugins/apiauth/apiauth_test.go
@@ -0,0 +1,20 @@
+package apiauth
+
+import (
+ "net/url"
+ "testing"
+)
+
+func TestSignature(t *testing.T) {
+ appsecret := "beego secret"
+ method := "GET"
+ RequestURL := "http://localhost/test/url"
+ params := make(url.Values)
+ params.Add("arg1", "hello")
+ params.Add("arg2", "beego")
+
+ signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58="
+ if Signature(appsecret, method, params, RequestURL) != signature {
+ t.Error("Signature error")
+ }
+}
diff --git a/plugins/authz/authz.go b/plugins/authz/authz.go
new file mode 100644
index 00000000..9dc0db76
--- /dev/null
+++ b/plugins/authz/authz.go
@@ -0,0 +1,86 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support.
+// Simple Usage:
+// import(
+// "github.com/astaxie/beego"
+// "github.com/astaxie/beego/plugins/authz"
+// "github.com/casbin/casbin"
+// )
+//
+// func main(){
+// // mediate the access for every request
+// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
+// beego.Run()
+// }
+//
+//
+// Advanced Usage:
+//
+// func main(){
+// e := casbin.NewEnforcer("authz_model.conf", "")
+// e.AddRoleForUser("alice", "admin")
+// e.AddPolicy(...)
+//
+// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e))
+// beego.Run()
+// }
+package authz
+
+import (
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context"
+ "github.com/casbin/casbin"
+ "net/http"
+)
+
+// NewAuthorizer returns the authorizer.
+// Use a casbin enforcer as input
+func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc {
+ return func(ctx *context.Context) {
+ a := &BasicAuthorizer{enforcer: e}
+
+ if !a.CheckPermission(ctx.Request) {
+ a.RequirePermission(ctx.ResponseWriter)
+ }
+ }
+}
+
+// BasicAuthorizer stores the casbin handler
+type BasicAuthorizer struct {
+ enforcer *casbin.Enforcer
+}
+
+// GetUserName gets the user name from the request.
+// Currently, only HTTP basic authentication is supported
+func (a *BasicAuthorizer) GetUserName(r *http.Request) string {
+ username, _, _ := r.BasicAuth()
+ return username
+}
+
+// CheckPermission checks the user/method/path combination from the request.
+// Returns true (permission granted) or false (permission forbidden)
+func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool {
+ user := a.GetUserName(r)
+ method := r.Method
+ path := r.URL.Path
+ return a.enforcer.Enforce(user, path, method)
+}
+
+// RequirePermission returns the 403 Forbidden to the client
+func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) {
+ w.WriteHeader(403)
+ w.Write([]byte("403 Forbidden\n"))
+}
diff --git a/plugins/authz/authz_model.conf b/plugins/authz/authz_model.conf
new file mode 100644
index 00000000..d1b3dbd7
--- /dev/null
+++ b/plugins/authz/authz_model.conf
@@ -0,0 +1,14 @@
+[request_definition]
+r = sub, obj, act
+
+[policy_definition]
+p = sub, obj, act
+
+[role_definition]
+g = _, _
+
+[policy_effect]
+e = some(where (p.eft == allow))
+
+[matchers]
+m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*")
\ No newline at end of file
diff --git a/plugins/authz/authz_policy.csv b/plugins/authz/authz_policy.csv
new file mode 100644
index 00000000..c062dd3e
--- /dev/null
+++ b/plugins/authz/authz_policy.csv
@@ -0,0 +1,7 @@
+p, alice, /dataset1/*, GET
+p, alice, /dataset1/resource1, POST
+p, bob, /dataset2/resource1, *
+p, bob, /dataset2/resource2, GET
+p, bob, /dataset2/folder1/*, POST
+p, dataset1_admin, /dataset1/*, *
+g, cathy, dataset1_admin
\ No newline at end of file
diff --git a/plugins/authz/authz_test.go b/plugins/authz/authz_test.go
new file mode 100644
index 00000000..49aed84c
--- /dev/null
+++ b/plugins/authz/authz_test.go
@@ -0,0 +1,107 @@
+// Copyright 2014 beego Author. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package authz
+
+import (
+ "github.com/astaxie/beego"
+ "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/plugins/auth"
+ "github.com/casbin/casbin"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) {
+ r, _ := http.NewRequest(method, path, nil)
+ r.SetBasicAuth(user, "123")
+ w := httptest.NewRecorder()
+ handler.ServeHTTP(w, r)
+
+ if w.Code != code {
+ t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code)
+ }
+}
+
+func TestBasic(t *testing.T) {
+ handler := beego.NewControllerRegister()
+
+ handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123"))
+ handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
+
+ handler.Any("*", func(ctx *context.Context) {
+ ctx.Output.SetStatus(200)
+ })
+
+ testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200)
+ testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200)
+ testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200)
+ testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403)
+}
+
+func TestPathWildcard(t *testing.T) {
+ handler := beego.NewControllerRegister()
+
+ handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123"))
+ handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
+
+ handler.Any("*", func(ctx *context.Context) {
+ ctx.Output.SetStatus(200)
+ })
+
+ testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200)
+ testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200)
+ testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200)
+ testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200)
+ testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403)
+ testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403)
+
+ testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403)
+ testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200)
+ testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403)
+ testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403)
+ testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200)
+ testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403)
+}
+
+func TestRBAC(t *testing.T) {
+ handler := beego.NewControllerRegister()
+
+ handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123"))
+ e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")
+ handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e))
+
+ handler.Any("*", func(ctx *context.Context) {
+ ctx.Output.SetStatus(200)
+ })
+
+ // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role.
+ testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200)
+ testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200)
+ testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200)
+ testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
+ testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
+ testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
+
+ // delete all roles on user cathy, so cathy cannot access any resources now.
+ e.DeleteRolesForUser("cathy")
+
+ testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403)
+ testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403)
+ testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403)
+ testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
+ testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
+ testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
+}
diff --git a/policy.go b/policy.go
index 2b91fdcc..ab23f927 100644
--- a/policy.go
+++ b/policy.go
@@ -23,7 +23,7 @@ import (
// PolicyFunc defines a policy function which is invoked before the controller handler is executed.
type PolicyFunc func(*context.Context)
-// FindRouter Find Router info for URL
+// FindPolicy Find Router info for URL
func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc {
var urlPath = cont.Input.URL()
if !BConfig.RouterCaseSensitive {
@@ -71,7 +71,7 @@ func (p *ControllerRegister) addToPolicy(method, pattern string, r ...PolicyFunc
}
}
-// Register new policy in beego
+// Policy Register new policy in beego
func Policy(pattern, method string, policy ...PolicyFunc) {
BeeApp.Handlers.addToPolicy(method, pattern, policy...)
}
diff --git a/router.go b/router.go
index 74cf02a1..e5a4e80d 100644
--- a/router.go
+++ b/router.go
@@ -17,7 +17,6 @@ package beego
import (
"fmt"
"net/http"
- "os"
"path"
"path/filepath"
"reflect"
@@ -28,6 +27,7 @@ import (
"time"
beecontext "github.com/astaxie/beego/context"
+ "github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils"
@@ -51,15 +51,22 @@ const (
var (
// HTTPMETHOD list the supported http methods.
HTTPMETHOD = map[string]string{
- "GET": "GET",
- "POST": "POST",
- "PUT": "PUT",
- "DELETE": "DELETE",
- "PATCH": "PATCH",
- "OPTIONS": "OPTIONS",
- "HEAD": "HEAD",
- "TRACE": "TRACE",
- "CONNECT": "CONNECT",
+ "GET": "GET",
+ "POST": "POST",
+ "PUT": "PUT",
+ "DELETE": "DELETE",
+ "PATCH": "PATCH",
+ "OPTIONS": "OPTIONS",
+ "HEAD": "HEAD",
+ "TRACE": "TRACE",
+ "CONNECT": "CONNECT",
+ "MKCOL": "MKCOL",
+ "COPY": "COPY",
+ "MOVE": "MOVE",
+ "PROPFIND": "PROPFIND",
+ "PROPPATCH": "PROPPATCH",
+ "LOCK": "LOCK",
+ "UNLOCK": "UNLOCK",
}
// these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
@@ -102,13 +109,15 @@ func ExceptMethodAppend(action string) {
exceptMethod = append(exceptMethod, action)
}
-type controllerInfo struct {
+// ControllerInfo holds information about the controller.
+type ControllerInfo struct {
pattern string
controllerType reflect.Type
methods map[string]string
handler http.Handler
runFunction FilterFunc
routerType int
+ methodParams []*param.MethodParam
}
// ControllerRegister containers registered router rules, controller handlers and filters.
@@ -144,6 +153,10 @@ func NewControllerRegister() *ControllerRegister {
// Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
+ p.addWithMethodParams(pattern, c, nil, mappingMethods...)
+}
+
+func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
methods := make(map[string]string)
@@ -169,11 +182,12 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM
}
}
- route := &controllerInfo{}
+ route := &ControllerInfo{}
route.pattern = pattern
route.methods = methods
route.routerType = routerTypeBeego
route.controllerType = t
+ route.methodParams = methodParams
if len(methods) == 0 {
for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
@@ -191,7 +205,7 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM
}
}
-func (p *ControllerRegister) addToRouter(method, pattern string, r *controllerInfo) {
+func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) {
if !BConfig.RouterCaseSensitive {
pattern = strings.ToLower(pattern)
}
@@ -212,13 +226,11 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
for _, c := range cList {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
- gopath := os.Getenv("GOPATH")
- if gopath == "" {
+ wgopath := utils.GetGOPATHs()
+ if len(wgopath) == 0 {
panic("you are in dev mode. So please set gopath")
}
pkgpath := ""
-
- wgopath := filepath.SplitList(gopath)
for _, wg := range wgopath {
wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath()))
if utils.FileExists(wg) {
@@ -240,7 +252,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
key := t.PkgPath() + ":" + t.Name()
if comm, ok := GlobalControllerRouter[key]; ok {
for _, a := range comm {
- p.Add(a.Router, c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
+ p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
}
}
}
@@ -328,7 +340,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
if _, ok := HTTPMETHOD[method]; method != "*" && !ok {
panic("not support http method: " + method)
}
- route := &controllerInfo{}
+ route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeRESTFul
route.runFunction = f
@@ -354,7 +366,7 @@ func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) {
// Handler add user defined Handler
func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) {
- route := &controllerInfo{}
+ route := &ControllerInfo{}
route.pattern = pattern
route.routerType = routerTypeHandler
route.handler = h
@@ -389,7 +401,7 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface)
controllerName := strings.TrimSuffix(ct.Name(), "Controller")
for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
- route := &controllerInfo{}
+ route := &ControllerInfo{}
route.routerType = routerTypeBeego
route.methods = map[string]string{"*": rt.Method(i).Name}
route.controllerType = ct
@@ -495,7 +507,7 @@ func (p *ControllerRegister) geturl(t *Tree, url, controllName, methodName strin
}
}
for _, l := range t.leaves {
- if c, ok := l.runObject.(*controllerInfo); ok {
+ if c, ok := l.runObject.(*ControllerInfo); ok {
if c.routerType == routerTypeBeego &&
strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllName) {
find := false
@@ -619,11 +631,12 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
startTime := time.Now()
var (
- runRouter reflect.Type
- findRouter bool
- runMethod string
- routerInfo *controllerInfo
- isRunnable bool
+ runRouter reflect.Type
+ findRouter bool
+ runMethod string
+ methodParams []*param.MethodParam
+ routerInfo *ControllerInfo
+ isRunnable bool
)
context := p.pool.Get().(*beecontext.Context)
context.Reset(rw, r)
@@ -663,7 +676,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
goto Admin
}
- if r.Method != "GET" && r.Method != "HEAD" {
+ if r.Method != http.MethodGet && r.Method != http.MethodHead {
if BConfig.CopyRequestBody && !context.Input.IsUpload() {
context.Input.CopyBody(BConfig.MaxMemory)
}
@@ -691,7 +704,6 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// User can define RunController and RunMethod in filter
if context.Input.RunController != nil && context.Input.RunMethod != "" {
findRouter = true
- isRunnable = true
runMethod = context.Input.RunMethod
runRouter = context.Input.RunController
} else {
@@ -735,12 +747,13 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
routerInfo.handler.ServeHTTP(rw, r)
} else {
runRouter = routerInfo.controllerType
+ methodParams = routerInfo.methodParams
method := r.Method
- if r.Method == "POST" && context.Input.Query("_method") == "PUT" {
- method = "PUT"
+			if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPut {
+ method = http.MethodPut
}
- if r.Method == "POST" && context.Input.Query("_method") == "DELETE" {
- method = "DELETE"
+ if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete {
+ method = http.MethodDelete
}
if m, ok := routerInfo.methods[method]; ok {
runMethod = m
@@ -770,8 +783,8 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
//if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf
if BConfig.WebConfig.EnableXSRF {
execController.XSRFToken()
- if r.Method == "POST" || r.Method == "DELETE" || r.Method == "PUT" ||
- (r.Method == "POST" && (context.Input.Query("_method") == "DELETE" || context.Input.Query("_method") == "PUT")) {
+ if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut ||
+ (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) {
execController.CheckXSRFCookie()
}
}
@@ -781,25 +794,30 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
if !context.ResponseWriter.Started {
//exec main logic
switch runMethod {
- case "GET":
+ case http.MethodGet:
execController.Get()
- case "POST":
+ case http.MethodPost:
execController.Post()
- case "DELETE":
+ case http.MethodDelete:
execController.Delete()
- case "PUT":
+ case http.MethodPut:
execController.Put()
- case "HEAD":
+ case http.MethodHead:
execController.Head()
- case "PATCH":
+ case http.MethodPatch:
execController.Patch()
- case "OPTIONS":
+ case http.MethodOptions:
execController.Options()
default:
if !execController.HandlerFunc(runMethod) {
- var in []reflect.Value
method := vc.MethodByName(runMethod)
- method.Call(in)
+ in := param.ConvertParams(methodParams, method.Type(), context)
+ out := method.Call(in)
+
+ //For backward compatibility we only handle response if we had incoming methodParams
+ if methodParams != nil {
+ p.handleParamResponse(context, execController, out)
+ }
}
}
@@ -830,7 +848,15 @@ Admin:
//admin module record QPS
if BConfig.Listen.EnableAdmin {
timeDur := time.Since(startTime)
- if FilterMonitorFunc(r.Method, r.URL.Path, timeDur) {
+ pattern := ""
+ if routerInfo != nil {
+ pattern = routerInfo.pattern
+ }
+ statusCode := context.ResponseWriter.Status
+ if statusCode == 0 {
+ statusCode = 200
+ }
+ if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) {
if runRouter != nil {
go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, runRouter.Name(), timeDur)
} else {
@@ -879,8 +905,22 @@ Admin:
}
}
+func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
+ //looping in reverse order for the case when both error and value are returned and error sets the response status code
+ for i := len(results) - 1; i >= 0; i-- {
+ result := results[i]
+ if result.Kind() != reflect.Interface || !result.IsNil() {
+ resultValue := result.Interface()
+ context.RenderMethodResult(resultValue)
+ }
+ }
+ if !context.ResponseWriter.Started && context.Output.Status == 0 {
+ context.Output.SetStatus(200)
+ }
+}
+
// FindRouter Find Router info for URL
-func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *controllerInfo, isFind bool) {
+func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) {
var urlPath = context.Input.URL()
if !BConfig.RouterCaseSensitive {
urlPath = strings.ToLower(urlPath)
@@ -888,7 +928,7 @@ func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo
httpMethod := context.Input.Method()
if t, ok := p.routers[httpMethod]; ok {
runObject := t.Match(urlPath, context)
- if r, ok := runObject.(*controllerInfo); ok {
+ if r, ok := runObject.(*ControllerInfo); ok {
return r, true
}
}
diff --git a/router_test.go b/router_test.go
index 936fd5e8..720b4ca8 100644
--- a/router_test.go
+++ b/router_test.go
@@ -502,10 +502,10 @@ func TestFilterBeforeRouter(t *testing.T) {
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
- if strings.Contains(rw.Body.String(), "BeforeRouter1") == false {
+ if !strings.Contains(rw.Body.String(), "BeforeRouter1") {
t.Errorf(testName + " BeforeRouter did not run")
}
- if strings.Contains(rw.Body.String(), "hello") == true {
+ if strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " BeforeRouter did not return properly")
}
}
@@ -525,13 +525,13 @@ func TestFilterBeforeExec(t *testing.T) {
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
- if strings.Contains(rw.Body.String(), "BeforeExec1") == false {
+ if !strings.Contains(rw.Body.String(), "BeforeExec1") {
t.Errorf(testName + " BeforeExec did not run")
}
- if strings.Contains(rw.Body.String(), "hello") == true {
+ if strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " BeforeExec did not return properly")
}
- if strings.Contains(rw.Body.String(), "BeforeRouter") == true {
+ if strings.Contains(rw.Body.String(), "BeforeRouter") {
t.Errorf(testName + " BeforeRouter ran in error")
}
}
@@ -552,16 +552,16 @@ func TestFilterAfterExec(t *testing.T) {
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
- if strings.Contains(rw.Body.String(), "AfterExec1") == false {
+ if !strings.Contains(rw.Body.String(), "AfterExec1") {
t.Errorf(testName + " AfterExec did not run")
}
- if strings.Contains(rw.Body.String(), "hello") == false {
+ if !strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " handler did not run properly")
}
- if strings.Contains(rw.Body.String(), "BeforeRouter") == true {
+ if strings.Contains(rw.Body.String(), "BeforeRouter") {
t.Errorf(testName + " BeforeRouter ran in error")
}
- if strings.Contains(rw.Body.String(), "BeforeExec") == true {
+ if strings.Contains(rw.Body.String(), "BeforeExec") {
t.Errorf(testName + " BeforeExec ran in error")
}
}
@@ -583,19 +583,19 @@ func TestFilterFinishRouter(t *testing.T) {
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
- if strings.Contains(rw.Body.String(), "FinishRouter1") == true {
+ if strings.Contains(rw.Body.String(), "FinishRouter1") {
t.Errorf(testName + " FinishRouter did not run")
}
- if strings.Contains(rw.Body.String(), "hello") == false {
+ if !strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " handler did not run properly")
}
- if strings.Contains(rw.Body.String(), "AfterExec1") == true {
+ if strings.Contains(rw.Body.String(), "AfterExec1") {
t.Errorf(testName + " AfterExec ran in error")
}
- if strings.Contains(rw.Body.String(), "BeforeRouter") == true {
+ if strings.Contains(rw.Body.String(), "BeforeRouter") {
t.Errorf(testName + " BeforeRouter ran in error")
}
- if strings.Contains(rw.Body.String(), "BeforeExec") == true {
+ if strings.Contains(rw.Body.String(), "BeforeExec") {
t.Errorf(testName + " BeforeExec ran in error")
}
}
@@ -615,14 +615,14 @@ func TestFilterFinishRouterMultiFirstOnly(t *testing.T) {
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
- if strings.Contains(rw.Body.String(), "FinishRouter1") == false {
+ if !strings.Contains(rw.Body.String(), "FinishRouter1") {
t.Errorf(testName + " FinishRouter1 did not run")
}
- if strings.Contains(rw.Body.String(), "hello") == false {
+ if !strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " handler did not run properly")
}
// not expected in body
- if strings.Contains(rw.Body.String(), "FinishRouter2") == true {
+ if strings.Contains(rw.Body.String(), "FinishRouter2") {
t.Errorf(testName + " FinishRouter2 did run")
}
}
@@ -642,44 +642,52 @@ func TestFilterFinishRouterMulti(t *testing.T) {
rw, r := testRequest("GET", url)
mux.ServeHTTP(rw, r)
- if strings.Contains(rw.Body.String(), "FinishRouter1") == false {
+ if !strings.Contains(rw.Body.String(), "FinishRouter1") {
t.Errorf(testName + " FinishRouter1 did not run")
}
- if strings.Contains(rw.Body.String(), "hello") == false {
+ if !strings.Contains(rw.Body.String(), "hello") {
t.Errorf(testName + " handler did not run properly")
}
- if strings.Contains(rw.Body.String(), "FinishRouter2") == false {
+ if !strings.Contains(rw.Body.String(), "FinishRouter2") {
t.Errorf(testName + " FinishRouter2 did not run properly")
}
}
func beegoFilterNoOutput(ctx *context.Context) {
- return
}
+
func beegoBeforeRouter1(ctx *context.Context) {
ctx.WriteString("|BeforeRouter1")
}
+
func beegoBeforeRouter2(ctx *context.Context) {
ctx.WriteString("|BeforeRouter2")
}
+
func beegoBeforeExec1(ctx *context.Context) {
ctx.WriteString("|BeforeExec1")
}
+
func beegoBeforeExec2(ctx *context.Context) {
ctx.WriteString("|BeforeExec2")
}
+
func beegoAfterExec1(ctx *context.Context) {
ctx.WriteString("|AfterExec1")
}
+
func beegoAfterExec2(ctx *context.Context) {
ctx.WriteString("|AfterExec2")
}
+
func beegoFinishRouter1(ctx *context.Context) {
ctx.WriteString("|FinishRouter1")
}
+
func beegoFinishRouter2(ctx *context.Context) {
ctx.WriteString("|FinishRouter2")
}
+
func beegoResetParams(ctx *context.Context) {
ctx.ResponseWriter.Header().Set("splat", ctx.Input.Param(":splat"))
}
diff --git a/session/couchbase/sess_couchbase.go b/session/couchbase/sess_couchbase.go
index d5be11d0..707d042c 100644
--- a/session/couchbase/sess_couchbase.go
+++ b/session/couchbase/sess_couchbase.go
@@ -155,11 +155,16 @@ func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error {
func (cp *Provider) SessionRead(sid string) (session.Store, error) {
cp.b = cp.getBucket()
- var doc []byte
+ var (
+ kv map[interface{}]interface{}
+ err error
+ doc []byte
+ )
- err := cp.b.Get(sid, &doc)
- var kv map[interface{}]interface{}
- if doc == nil {
+ err = cp.b.Get(sid, &doc)
+ if err != nil {
+ return nil, err
+ } else if doc == nil {
kv = make(map[interface{}]interface{})
} else {
kv, err = session.DecodeGob(doc)
@@ -230,7 +235,6 @@ func (cp *Provider) SessionDestroy(sid string) error {
// SessionGC Recycle
func (cp *Provider) SessionGC() {
- return
}
// SessionAll return all active session
diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go
index 68f37b08..77685d1e 100644
--- a/session/ledis/ledis_session.go
+++ b/session/ledis/ledis_session.go
@@ -12,8 +12,10 @@ import (
"github.com/siddontang/ledisdb/ledis"
)
-var ledispder = &Provider{}
-var c *ledis.DB
+var (
+ ledispder = &Provider{}
+ c *ledis.DB
+)
// SessionStore ledis session store
type SessionStore struct {
@@ -97,27 +99,33 @@ func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error {
}
cfg := new(config.Config)
cfg.DataDir = lp.savePath
- nowLedis, err := ledis.Open(cfg)
- c, err = nowLedis.Select(lp.db)
+
+ var ledisInstance *ledis.Ledis
+ ledisInstance, err = ledis.Open(cfg)
if err != nil {
- println(err)
- return nil
+ return err
}
- return nil
+ c, err = ledisInstance.Select(lp.db)
+ return err
}
// SessionRead read ledis session by sid
func (lp *Provider) SessionRead(sid string) (session.Store, error) {
- kvs, err := c.Get([]byte(sid))
- var kv map[interface{}]interface{}
+ var (
+ kv map[interface{}]interface{}
+ err error
+ )
+
+ kvs, _ := c.Get([]byte(sid))
+
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
- kv, err = session.DecodeGob(kvs)
- if err != nil {
+ if kv, err = session.DecodeGob(kvs); err != nil {
return nil, err
}
}
+
ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
return ls, nil
}
@@ -125,10 +133,7 @@ func (lp *Provider) SessionRead(sid string) (session.Store, error) {
// SessionExist check ledis session exist by sid
func (lp *Provider) SessionExist(sid string) bool {
count, _ := c.Exists([]byte(sid))
- if count == 0 {
- return false
- }
- return true
+ return count != 0
}
// SessionRegenerate generate new sid for ledis session
@@ -145,18 +150,7 @@ func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
c.Set([]byte(sid), data)
c.Expire([]byte(sid), lp.maxlifetime)
}
- kvs, err := c.Get([]byte(sid))
- var kv map[interface{}]interface{}
- if len(kvs) == 0 {
- kv = make(map[interface{}]interface{})
- } else {
- kv, err = session.DecodeGob([]byte(kvs))
- if err != nil {
- return nil, err
- }
- }
- ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime}
- return ls, nil
+ return lp.SessionRead(sid)
}
// SessionDestroy delete ledis session by id
@@ -167,7 +161,6 @@ func (lp *Provider) SessionDestroy(sid string) error {
// SessionGC Impelment method, no used.
func (lp *Provider) SessionGC() {
- return
}
// SessionAll return all active session
diff --git a/session/memcache/sess_memcache.go b/session/memcache/sess_memcache.go
index f1069bc9..755979c4 100644
--- a/session/memcache/sess_memcache.go
+++ b/session/memcache/sess_memcache.go
@@ -205,11 +205,7 @@ func (rp *MemProvider) SessionDestroy(sid string) error {
}
}
- err := client.Delete(sid)
- if err != nil {
- return err
- }
- return nil
+ return client.Delete(sid)
}
func (rp *MemProvider) connectInit() error {
@@ -219,7 +215,6 @@ func (rp *MemProvider) connectInit() error {
// SessionGC Impelment method, no used.
func (rp *MemProvider) SessionGC() {
- return
}
// SessionAll return all activeSession
diff --git a/session/mysql/sess_mysql.go b/session/mysql/sess_mysql.go
index 838ec669..4c9251e7 100644
--- a/session/mysql/sess_mysql.go
+++ b/session/mysql/sess_mysql.go
@@ -170,10 +170,7 @@ func (mp *Provider) SessionExist(sid string) bool {
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
- if err == sql.ErrNoRows {
- return false
- }
- return true
+ return err != sql.ErrNoRows
}
// SessionRegenerate generate new sid for mysql session
@@ -212,7 +209,6 @@ func (mp *Provider) SessionGC() {
c := mp.connectInit()
c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime)
c.Close()
- return
}
// SessionAll count values in mysql session
diff --git a/session/postgres/sess_postgresql.go b/session/postgres/sess_postgresql.go
index 73f9c13a..ffc27def 100644
--- a/session/postgres/sess_postgresql.go
+++ b/session/postgres/sess_postgresql.go
@@ -184,11 +184,7 @@ func (mp *Provider) SessionExist(sid string) bool {
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
-
- if err == sql.ErrNoRows {
- return false
- }
- return true
+ return err != sql.ErrNoRows
}
// SessionRegenerate generate new sid for postgresql session
@@ -228,7 +224,6 @@ func (mp *Provider) SessionGC() {
c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close()
- return
}
// SessionAll count values in postgresql session
diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go
index c46fa7cd..d0424515 100644
--- a/session/redis/sess_redis.go
+++ b/session/redis/sess_redis.go
@@ -128,7 +128,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
}
if len(configs) > 1 {
poolsize, err := strconv.Atoi(configs[1])
- if err != nil || poolsize <= 0 {
+ if err != nil || poolsize < 0 {
rp.poolsize = MaxPoolSize
} else {
rp.poolsize = poolsize
@@ -155,7 +155,7 @@ func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error {
return nil, err
}
if rp.password != "" {
- if _, err := c.Do("AUTH", rp.password); err != nil {
+ if _, err = c.Do("AUTH", rp.password); err != nil {
c.Close()
return nil, err
}
@@ -176,13 +176,16 @@ func (rp *Provider) SessionRead(sid string) (session.Store, error) {
c := rp.poollist.Get()
defer c.Close()
- kvs, err := redis.String(c.Do("GET", sid))
var kv map[interface{}]interface{}
+
+ kvs, err := redis.String(c.Do("GET", sid))
+ if err != nil && err != redis.ErrNil {
+ return nil, err
+ }
if len(kvs) == 0 {
kv = make(map[interface{}]interface{})
} else {
- kv, err = session.DecodeGob([]byte(kvs))
- if err != nil {
+ if kv, err = session.DecodeGob([]byte(kvs)); err != nil {
return nil, err
}
}
@@ -216,20 +219,7 @@ func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error)
c.Do("RENAME", oldsid, sid)
c.Do("EXPIRE", sid, rp.maxlifetime)
}
-
- kvs, err := redis.String(c.Do("GET", sid))
- var kv map[interface{}]interface{}
- if len(kvs) == 0 {
- kv = make(map[interface{}]interface{})
- } else {
- kv, err = session.DecodeGob([]byte(kvs))
- if err != nil {
- return nil, err
- }
- }
-
- rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime}
- return rs, nil
+ return rp.SessionRead(sid)
}
// SessionDestroy delete redis session by id
@@ -243,7 +233,6 @@ func (rp *Provider) SessionDestroy(sid string) error {
// SessionGC Impelment method, no used.
func (rp *Provider) SessionGC() {
- return
}
// SessionAll return all activeSession
diff --git a/session/sess_cookie.go b/session/sess_cookie.go
index 3fefa360..145e53c9 100644
--- a/session/sess_cookie.go
+++ b/session/sess_cookie.go
@@ -74,21 +74,16 @@ func (st *CookieSessionStore) SessionID() string {
// SessionRelease Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
- str, err := encodeCookie(cookiepder.block,
- cookiepder.config.SecurityKey,
- cookiepder.config.SecurityName,
- st.values)
- if err != nil {
- return
+ encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values)
+ if err == nil {
+ cookie := &http.Cookie{Name: cookiepder.config.CookieName,
+ Value: url.QueryEscape(encodedCookie),
+ Path: "/",
+ HttpOnly: true,
+ Secure: cookiepder.config.Secure,
+ MaxAge: cookiepder.config.Maxage}
+ http.SetCookie(w, cookie)
}
- cookie := &http.Cookie{Name: cookiepder.config.CookieName,
- Value: url.QueryEscape(str),
- Path: "/",
- HttpOnly: true,
- Secure: cookiepder.config.Secure,
- MaxAge: cookiepder.config.Maxage}
- http.SetCookie(w, cookie)
- return
}
type cookieConfig struct {
@@ -166,7 +161,6 @@ func (pder *CookieProvider) SessionDestroy(sid string) error {
// SessionGC Implement method, no used.
func (pder *CookieProvider) SessionGC() {
- return
}
// SessionAll Implement method, return 0.
diff --git a/session/sess_file.go b/session/sess_file.go
index e8484db3..0758d6b4 100644
--- a/session/sess_file.go
+++ b/session/sess_file.go
@@ -15,8 +15,7 @@
package session
import (
- "errors"
- "io"
+ "fmt"
"io/ioutil"
"net/http"
"os"
@@ -88,9 +87,16 @@ func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
var f *os.File
if err == nil {
f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777)
+ if err != nil {
+ SLogger.Println(err)
+ return
+ }
} else if os.IsNotExist(err) {
f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
-
+ if err != nil {
+ SLogger.Println(err)
+ return
+ }
} else {
return
}
@@ -135,6 +141,9 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
} else {
return nil, err
}
+
+ defer f.Close()
+
os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
var kv map[interface{}]interface{}
b, err := ioutil.ReadAll(f)
@@ -149,7 +158,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
return nil, err
}
}
- f.Close()
+
ss := &FileSessionStore{sid: sid, values: kv}
return ss, nil
}
@@ -161,10 +170,7 @@ func (fp *FileProvider) SessionExist(sid string) bool {
defer filepder.lock.Unlock()
_, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
- if err == nil {
- return true
- }
- return false
+ return err == nil
}
// SessionDestroy Remove all files in this save path
@@ -204,49 +210,58 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
filepder.lock.Lock()
defer filepder.lock.Unlock()
- err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777)
- if err != nil {
- SLogger.Println(err.Error())
- }
- err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
- if err != nil {
- SLogger.Println(err.Error())
- }
- _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
- var newf *os.File
+ oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]))
+ oldSidFile := path.Join(oldPath, oldsid)
+ newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1]))
+ newSidFile := path.Join(newPath, sid)
+
+ // return an error if the new sid file already exists
+ _, err := os.Stat(newSidFile)
if err == nil {
- return nil, errors.New("newsid exist")
- } else if os.IsNotExist(err) {
- newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
+ return nil, fmt.Errorf("newsid %s exist", newSidFile)
}
- _, err = os.Stat(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid))
- var f *os.File
- if err == nil {
- f, err = os.OpenFile(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid), os.O_RDWR, 0777)
- io.Copy(newf, f)
- } else if os.IsNotExist(err) {
- newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
- } else {
- return nil, err
- }
- f.Close()
- os.Remove(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])))
- os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
- var kv map[interface{}]interface{}
- b, err := ioutil.ReadAll(newf)
+ err = os.MkdirAll(newPath, 0777)
if err != nil {
- return nil, err
+ SLogger.Println(err.Error())
}
- if len(b) == 0 {
- kv = make(map[interface{}]interface{})
- } else {
- kv, err = DecodeGob(b)
+
+ // if the old sid file exists:
+ // 1. read and parse the file content
+ // 2. write the content to the new sid file
+ // 3. remove the old sid file, update the new sid file's atime and ctime
+ // 4. return a FileSessionStore
+ _, err = os.Stat(oldSidFile)
+ if err == nil {
+ b, err := ioutil.ReadFile(oldSidFile)
if err != nil {
return nil, err
}
+
+ var kv map[interface{}]interface{}
+ if len(b) == 0 {
+ kv = make(map[interface{}]interface{})
+ } else {
+ kv, err = DecodeGob(b)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ ioutil.WriteFile(newSidFile, b, 0777)
+ os.Remove(oldSidFile)
+ os.Chtimes(newSidFile, time.Now(), time.Now())
+ ss := &FileSessionStore{sid: sid, values: kv}
+ return ss, nil
}
- ss := &FileSessionStore{sid: sid, values: kv}
+
+ // if the old sid file does not exist, just create the new sid file and return
+ newf, err := os.Create(newSidFile)
+ if err != nil {
+ return nil, err
+ }
+ newf.Close()
+ ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})}
return ss, nil
}
diff --git a/session/sess_test.go b/session/sess_test.go
index b40865f3..906abec2 100644
--- a/session/sess_test.go
+++ b/session/sess_test.go
@@ -74,8 +74,7 @@ func TestCookieEncodeDecode(t *testing.T) {
if err != nil {
t.Fatal("encodeCookie:", err)
}
- dst := make(map[interface{}]interface{})
- dst, err = decodeCookie(block, hashKey, securityName, str, 3600)
+ dst, err := decodeCookie(block, hashKey, securityName, str, 3600)
if err != nil {
t.Fatal("decodeCookie", err)
}
@@ -115,7 +114,7 @@ func TestParseConfig(t *testing.T) {
if cf2.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
- if cf2.EnableSetCookie != false {
+ if cf2.EnableSetCookie {
t.Fatal("parseconfig get enableSetCookie error")
}
cconfig := new(cookieConfig)
diff --git a/session/session.go b/session/session.go
index fb4b2821..cf647521 100644
--- a/session/session.go
+++ b/session/session.go
@@ -81,6 +81,7 @@ func Register(name string, provide Provider) {
provides[name] = provide
}
+// ManagerConfig defines the session config.
type ManagerConfig struct {
CookieName string `json:"cookieName"`
EnableSetCookie bool `json:"enableSetCookie,omitempty"`
@@ -92,9 +93,9 @@ type ManagerConfig struct {
ProviderConfig string `json:"providerConfig"`
Domain string `json:"domain"`
SessionIDLength int64 `json:"sessionIDLength"`
- EnableSidInHttpHeader bool `json:"enableSidInHttpHeader"`
- SessionNameInHttpHeader string `json:"sessionNameInHttpHeader"`
- EnableSidInUrlQuery bool `json:"enableSidInUrlQuery"`
+ EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"`
+ SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"`
+ EnableSidInURLQuery bool `json:"EnableSidInURLQuery"`
}
// Manager contains Provider and its configuration.
@@ -125,14 +126,14 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) {
cf.Maxlifetime = cf.Gclifetime
}
- if cf.EnableSidInHttpHeader {
- if cf.SessionNameInHttpHeader == "" {
- panic(errors.New("SessionNameInHttpHeader is empty"))
+ if cf.EnableSidInHTTPHeader {
+ if cf.SessionNameInHTTPHeader == "" {
+ panic(errors.New("SessionNameInHTTPHeader is empty"))
}
- strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHttpHeader)
- if cf.SessionNameInHttpHeader != strMimeHeader {
- strErrMsg := "SessionNameInHttpHeader (" + cf.SessionNameInHttpHeader + ") has the wrong format, it should be like this : " + strMimeHeader
+ strMimeHeader := textproto.CanonicalMIMEHeaderKey(cf.SessionNameInHTTPHeader)
+ if cf.SessionNameInHTTPHeader != strMimeHeader {
+ strErrMsg := "SessionNameInHTTPHeader (" + cf.SessionNameInHTTPHeader + ") has the wrong format, it should be like this : " + strMimeHeader
panic(errors.New(strErrMsg))
}
}
@@ -163,7 +164,7 @@ func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName)
if errs != nil || cookie.Value == "" {
var sid string
- if manager.config.EnableSidInUrlQuery {
+ if manager.config.EnableSidInURLQuery {
errs := r.ParseForm()
if errs != nil {
return "", errs
@@ -173,8 +174,8 @@ func (manager *Manager) getSid(r *http.Request) (string, error) {
}
// if not found in Cookie / param, then read it from request headers
- if manager.config.EnableSidInHttpHeader && sid == "" {
- sids, isFound := r.Header[manager.config.SessionNameInHttpHeader]
+ if manager.config.EnableSidInHTTPHeader && sid == "" {
+ sids, isFound := r.Header[manager.config.SessionNameInHTTPHeader]
if isFound && len(sids) != 0 {
return sids[0], nil
}
@@ -226,9 +227,9 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
}
r.AddCookie(cookie)
- if manager.config.EnableSidInHttpHeader {
- r.Header.Set(manager.config.SessionNameInHttpHeader, sid)
- w.Header().Set(manager.config.SessionNameInHttpHeader, sid)
+ if manager.config.EnableSidInHTTPHeader {
+ r.Header.Set(manager.config.SessionNameInHTTPHeader, sid)
+ w.Header().Set(manager.config.SessionNameInHTTPHeader, sid)
}
return
@@ -236,9 +237,9 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
// SessionDestroy Destroy session by its id in http request cookie.
func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
- if manager.config.EnableSidInHttpHeader {
- r.Header.Del(manager.config.SessionNameInHttpHeader)
- w.Header().Del(manager.config.SessionNameInHttpHeader)
+ if manager.config.EnableSidInHTTPHeader {
+ r.Header.Del(manager.config.SessionNameInHTTPHeader)
+ w.Header().Del(manager.config.SessionNameInHTTPHeader)
}
cookie, err := r.Cookie(manager.config.CookieName)
@@ -306,9 +307,9 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque
}
r.AddCookie(cookie)
- if manager.config.EnableSidInHttpHeader {
- r.Header.Set(manager.config.SessionNameInHttpHeader, sid)
- w.Header().Set(manager.config.SessionNameInHttpHeader, sid)
+ if manager.config.EnableSidInHTTPHeader {
+ r.Header.Set(manager.config.SessionNameInHTTPHeader, sid)
+ w.Header().Set(manager.config.SessionNameInHTTPHeader, sid)
}
return
@@ -328,7 +329,7 @@ func (manager *Manager) sessionID() (string, error) {
b := make([]byte, manager.config.SessionIDLength)
n, err := rand.Read(b)
if n != len(b) || err != nil {
- return "", fmt.Errorf("Could not successfully read from the system CSPRNG.")
+ return "", fmt.Errorf("Could not successfully read from the system CSPRNG")
}
return hex.EncodeToString(b), nil
}
diff --git a/session/ssdb/sess_ssdb.go b/session/ssdb/sess_ssdb.go
index 4dcf160a..de0c6360 100644
--- a/session/ssdb/sess_ssdb.go
+++ b/session/ssdb/sess_ssdb.go
@@ -11,44 +11,40 @@ import (
"github.com/ssdb/gossdb/ssdb"
)
-var ssdbProvider = &SsdbProvider{}
+var ssdbProvider = &Provider{}
-type SsdbProvider struct {
+// Provider holds ssdb client and configs
+type Provider struct {
client *ssdb.Client
host string
port int
maxLifetime int64
}
-func (p *SsdbProvider) connectInit() error {
+func (p *Provider) connectInit() error {
var err error
if p.host == "" || p.port == 0 {
return errors.New("SessionInit First")
}
p.client, err = ssdb.Connect(p.host, p.port)
- if err != nil {
- return err
- }
- return nil
+ return err
}
-func (p *SsdbProvider) SessionInit(maxLifetime int64, savePath string) error {
- var e error = nil
+// SessionInit initializes the ssdb client with the given config.
+func (p *Provider) SessionInit(maxLifetime int64, savePath string) error {
p.maxLifetime = maxLifetime
address := strings.Split(savePath, ":")
p.host = address[0]
- p.port, e = strconv.Atoi(address[1])
- if e != nil {
- return e
- }
- err := p.connectInit()
- if err != nil {
+
+ var err error
+ if p.port, err = strconv.Atoi(address[1]); err != nil {
return err
}
- return nil
+ return p.connectInit()
}
-func (p *SsdbProvider) SessionRead(sid string) (session.Store, error) {
+// SessionRead returns a ssdb client session Store.
+func (p *Provider) SessionRead(sid string) (session.Store, error) {
if p.client == nil {
if err := p.connectInit(); err != nil {
return nil, err
@@ -71,7 +67,8 @@ func (p *SsdbProvider) SessionRead(sid string) (session.Store, error) {
return rs, nil
}
-func (p *SsdbProvider) SessionExist(sid string) bool {
+// SessionExist reports whether the sid exists in the session store.
+func (p *Provider) SessionExist(sid string) bool {
if p.client == nil {
if err := p.connectInit(); err != nil {
panic(err)
@@ -85,9 +82,10 @@ func (p *SsdbProvider) SessionExist(sid string) bool {
return false
}
return true
-
}
-func (p *SsdbProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
+
+// SessionRegenerate regenerates the session with the new sid and deletes the old one.
+func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
//conn.Do("setx", key, v, ttl)
if p.client == nil {
if err := p.connectInit(); err != nil {
@@ -119,27 +117,27 @@ func (p *SsdbProvider) SessionRegenerate(oldsid, sid string) (session.Store, err
return rs, nil
}
-func (p *SsdbProvider) SessionDestroy(sid string) error {
+// SessionDestroy destroys the session with the given sid.
+func (p *Provider) SessionDestroy(sid string) error {
if p.client == nil {
if err := p.connectInit(); err != nil {
return err
}
}
_, err := p.client.Del(sid)
- if err != nil {
- return err
- }
- return nil
+ return err
}
-func (p *SsdbProvider) SessionGC() {
- return
+// SessionGC not implemented
+func (p *Provider) SessionGC() {
}
-func (p *SsdbProvider) SessionAll() int {
+// SessionAll not implemented
+func (p *Provider) SessionAll() int {
return 0
}
+// SessionStore holds the session information which is stored in ssdb.
type SessionStore struct {
sid string
lock sync.RWMutex
@@ -148,12 +146,15 @@ type SessionStore struct {
client *ssdb.Client
}
+// Set the key and value
func (s *SessionStore) Set(key, value interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
s.values[key] = value
return nil
}
+
+// Get returns the value by the key.
func (s *SessionStore) Get(key interface{}) interface{} {
s.lock.Lock()
defer s.lock.Unlock()
@@ -163,30 +164,36 @@ func (s *SessionStore) Get(key interface{}) interface{} {
return nil
}
+// Delete the key in session store
func (s *SessionStore) Delete(key interface{}) error {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.values, key)
return nil
}
+
+// Flush deletes all keys and values.
func (s *SessionStore) Flush() error {
s.lock.Lock()
defer s.lock.Unlock()
s.values = make(map[interface{}]interface{})
return nil
}
+
+// SessionID returns the session id.
func (s *SessionStore) SessionID() string {
return s.sid
}
+// SessionRelease stores the key-values into ssdb.
func (s *SessionStore) SessionRelease(w http.ResponseWriter) {
b, err := session.EncodeGob(s.values)
if err != nil {
return
}
s.client.Do("setx", s.sid, string(b), s.maxLifetime)
-
}
+
func init() {
session.Register("ssdb", ssdbProvider)
}
diff --git a/staticfile.go b/staticfile.go
index b7be24f3..bbb2a1fb 100644
--- a/staticfile.go
+++ b/staticfile.go
@@ -90,8 +90,6 @@ func serverStaticRouter(ctx *context.Context) {
}
http.ServeContent(ctx.ResponseWriter, ctx.Request, filePath, sch.modTime, sch)
- return
-
}
type serveContentHolder struct {
@@ -109,14 +107,14 @@ var (
func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, error) {
mapKey := acceptEncoding + ":" + filePath
mapLock.RLock()
- mapFile, _ := staticFileMap[mapKey]
+ mapFile := staticFileMap[mapKey]
mapLock.RUnlock()
if isOk(mapFile, fi) {
return mapFile.encoding != "", mapFile.encoding, mapFile, nil
}
mapLock.Lock()
defer mapLock.Unlock()
- if mapFile, _ = staticFileMap[mapKey]; !isOk(mapFile, fi) {
+ if mapFile = staticFileMap[mapKey]; !isOk(mapFile, fi) {
file, err := os.Open(filePath)
if err != nil {
return false, "", nil, err
diff --git a/swagger/swagger.go b/swagger/swagger.go
index e0ac5cf5..035d5a49 100644
--- a/swagger/swagger.go
+++ b/swagger/swagger.go
@@ -22,19 +22,19 @@ package swagger
// Swagger list the resource
type Swagger struct {
- SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"`
- Infos Information `json:"info" yaml:"info"`
- Host string `json:"host,omitempty" yaml:"host,omitempty"`
- BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"`
- Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"`
- Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"`
- Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"`
- Paths map[string]*Item `json:"paths" yaml:"paths"`
- Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"`
- SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"`
- Security map[string][]string `json:"security,omitempty" yaml:"security,omitempty"`
- Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"`
- ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"`
+ SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"`
+ Infos Information `json:"info" yaml:"info"`
+ Host string `json:"host,omitempty" yaml:"host,omitempty"`
+ BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"`
+ Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"`
+ Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"`
+ Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"`
+ Paths map[string]*Item `json:"paths" yaml:"paths"`
+ Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"`
+ SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"`
+ Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"`
+ Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"`
+ ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"`
}
// Information Provides metadata about the API. The metadata can be used by the clients if needed.
@@ -75,16 +75,17 @@ type Item struct {
// Operation Describes a single API operation on a path.
type Operation struct {
- Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"`
- Summary string `json:"summary,omitempty" yaml:"summary,omitempty"`
- Description string `json:"description,omitempty" yaml:"description,omitempty"`
- OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"`
- Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"`
- Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"`
- Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"`
- Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"`
- Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"`
- Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"`
+ Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"`
+ Summary string `json:"summary,omitempty" yaml:"summary,omitempty"`
+ Description string `json:"description,omitempty" yaml:"description,omitempty"`
+ OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"`
+ Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"`
+ Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"`
+ Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"`
+ Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"`
+ Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"`
+ Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"`
+ Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"`
}
// Parameter Describes a single operation parameter.
@@ -100,7 +101,7 @@ type Parameter struct {
Default interface{} `json:"default,omitempty" yaml:"default,omitempty"`
}
-// A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body".
+// ParameterItems A limited subset of JSON-Schema's items object. It is used by parameter definitions that are not located in "body".
// http://swagger.io/specification/#itemsObject
type ParameterItems struct {
Type string `json:"type,omitempty" yaml:"type,omitempty"`
diff --git a/template.go b/template.go
index 5415f5f0..d4859cd7 100644
--- a/template.go
+++ b/template.go
@@ -31,10 +31,11 @@ import (
)
var (
- beegoTplFuncMap = make(template.FuncMap)
- // beeTemplates caching map and supported template file extensions.
- beeTemplates = make(map[string]*template.Template)
- templatesLock sync.RWMutex
+ beegoTplFuncMap = make(template.FuncMap)
+ beeViewPathTemplateLocked = false
+ // beeViewPathTemplates caching map and supported template file extensions per view
+ beeViewPathTemplates = make(map[string]map[string]*template.Template)
+ templatesLock sync.RWMutex
// beeTemplateExt stores the template extension which will build
beeTemplateExt = []string{"tpl", "html"}
// beeTemplatePreprocessors stores associations of extension -> preprocessor handler
@@ -45,23 +46,33 @@ var (
// writing the output to wr.
// A template will be executed safely in parallel.
func ExecuteTemplate(wr io.Writer, name string, data interface{}) error {
+ return ExecuteViewPathTemplate(wr, name, BConfig.WebConfig.ViewsPath, data)
+}
+
+// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object,
+// writing the output to wr.
+// A template will be executed safely in parallel.
+func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error {
if BConfig.RunMode == DEV {
templatesLock.RLock()
defer templatesLock.RUnlock()
}
- if t, ok := beeTemplates[name]; ok {
- var err error
- if t.Lookup(name) != nil {
- err = t.ExecuteTemplate(wr, name, data)
- } else {
- err = t.Execute(wr, data)
+ if beeTemplates, ok := beeViewPathTemplates[viewPath]; ok {
+ if t, ok := beeTemplates[name]; ok {
+ var err error
+ if t.Lookup(name) != nil {
+ err = t.ExecuteTemplate(wr, name, data)
+ } else {
+ err = t.Execute(wr, data)
+ }
+ if err != nil {
+ logs.Trace("template Execute err:", err)
+ }
+ return err
}
- if err != nil {
- logs.Trace("template Execute err:", err)
- }
- return err
+ panic("can't find templatefile in the path:" + viewPath + "/" + name)
}
- panic("can't find templatefile in the path:" + name)
+ panic("Unknown view path:" + viewPath)
}
func init() {
@@ -149,6 +160,24 @@ func AddTemplateExt(ext string) {
beeTemplateExt = append(beeTemplateExt, ext)
}
+// AddViewPath adds a new path to the supported view paths.
+// Can later be used by setting a controller ViewPath to this folder
+// Will panic if called after beego.Run()
+func AddViewPath(viewPath string) error {
+ if beeViewPathTemplateLocked {
+ if _, exist := beeViewPathTemplates[viewPath]; exist {
+ return nil //Ignore if viewpath already exists
+ }
+ panic("Can not add new view paths after beego.Run()")
+ }
+ beeViewPathTemplates[viewPath] = make(map[string]*template.Template)
+ return BuildTemplate(viewPath)
+}
+
+func lockViewPaths() {
+ beeViewPathTemplateLocked = true
+}
+
// BuildTemplate will build all template files in a directory.
// it makes beego can render any template file in view directory.
func BuildTemplate(dir string, files ...string) error {
@@ -158,6 +187,10 @@ func BuildTemplate(dir string, files ...string) error {
}
return errors.New("dir open err")
}
+ beeTemplates, ok := beeViewPathTemplates[dir]
+ if !ok {
+ panic("Unknown view path: " + dir)
+ }
self := &templateFile{
root: dir,
files: make(map[string][]string),
@@ -184,7 +217,7 @@ func BuildTemplate(dir string, files ...string) error {
t, err = getTemplate(self.root, file, v...)
}
if err != nil {
- logs.Trace("parse template err:", file, err)
+ logs.Error("parse template err:", file, err)
} else {
beeTemplates[file] = t
}
@@ -197,9 +230,12 @@ func BuildTemplate(dir string, files ...string) error {
func getTplDeep(root, file, parent string, t *template.Template) (*template.Template, [][]string, error) {
var fileAbsPath string
+ var rParent string
if filepath.HasPrefix(file, "../") {
+ rParent = filepath.Join(filepath.Dir(parent), file)
fileAbsPath = filepath.Join(root, filepath.Dir(parent), file)
} else {
+ rParent = file
fileAbsPath = filepath.Join(root, file)
}
if e := utils.FileExists(fileAbsPath); !e {
@@ -224,7 +260,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
if !HasTemplateExt(m[1]) {
continue
}
- t, _, err = getTplDeep(root, m[1], file, t)
+ _, _, err = getTplDeep(root, m[1], rParent, t)
if err != nil {
return nil, [][]string{}, err
}
@@ -263,7 +299,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
t, subMods1, err = getTplDeep(root, otherFile, "", t)
if err != nil {
logs.Trace("template parse file err:", err)
- } else if subMods1 != nil && len(subMods1) > 0 {
+ } else if len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...)
}
break
@@ -271,8 +307,9 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
}
//second check define
for _, otherFile := range others {
+ var data []byte
fileAbsPath := filepath.Join(root, otherFile)
- data, err := ioutil.ReadFile(fileAbsPath)
+ data, err = ioutil.ReadFile(fileAbsPath)
if err != nil {
continue
}
@@ -284,7 +321,7 @@ func _getTemplate(t0 *template.Template, root string, subMods [][]string, others
t, subMods1, err = getTplDeep(root, otherFile, "", t)
if err != nil {
logs.Trace("template parse file err:", err)
- } else if subMods1 != nil && len(subMods1) > 0 {
+ } else if len(subMods1) > 0 {
t, err = _getTemplate(t, root, subMods1, others...)
}
break
@@ -328,6 +365,7 @@ func DelStaticPath(url string) *App {
return BeeApp
}
+// AddTemplateEngine adds a new templatePreProcessor which supports the given extension
func AddTemplateEngine(extension string, fn templatePreProcessor) *App {
AddTemplateExt(extension)
beeTemplateEngines[extension] = fn
diff --git a/template_test.go b/template_test.go
index 4f13736c..2153ef72 100644
--- a/template_test.go
+++ b/template_test.go
@@ -15,6 +15,7 @@
package beego
import (
+ "bytes"
"os"
"path/filepath"
"testing"
@@ -67,9 +68,10 @@ func TestTemplate(t *testing.T) {
f.Close()
}
}
- if err := BuildTemplate(dir); err != nil {
+ if err := AddViewPath(dir); err != nil {
t.Fatal(err)
}
+ beeTemplates := beeViewPathTemplates[dir]
if len(beeTemplates) != 3 {
t.Fatalf("should be 3 but got %v", len(beeTemplates))
}
@@ -103,6 +105,12 @@ var user = `
func TestRelativeTemplate(t *testing.T) {
dir := "_beeTmp"
+
+ //Just add dir to known viewPaths
+ if err := AddViewPath(dir); err != nil {
+ t.Fatal(err)
+ }
+
files := []string{
"easyui/public/menu.tpl",
"easyui/rbac/user.tpl",
@@ -126,6 +134,7 @@ func TestRelativeTemplate(t *testing.T) {
if err := BuildTemplate(dir, files[1]); err != nil {
t.Fatal(err)
}
+ beeTemplates := beeViewPathTemplates[dir]
if err := beeTemplates["easyui/rbac/user.tpl"].ExecuteTemplate(os.Stdout, "easyui/rbac/user.tpl", nil); err != nil {
t.Fatal(err)
}
@@ -134,3 +143,116 @@ func TestRelativeTemplate(t *testing.T) {
}
os.RemoveAll(dir)
}
+
+var add = `{{ template "layout_blog.tpl" . }}
+{{ define "css" }}
+
+{{ end}}
+
+
+{{ define "content" }}
+
{{ .Title }}
+
This is SomeVar: {{ .SomeVar }}
+{{ end }}
+
+{{ define "js" }}
+
+{{ end}}`
+
+var layoutBlog = `
+
+
+ Lin Li
+
+
+
+
+ {{ block "css" . }}{{ end }}
+
+
+
+